@@ -22,7 +22,7 @@ def __init__(
         decay=0, penalty_L1=False, penalty_local=False,
         epochs_per_monitor=1,
         monitor_ppl=False, monitor_time=False,
-        increase_speed=0,
+        increase_speed=0, rms_decay=0.9, adam_decay1=0.9, adam_decay2=0.999,
         cd_type='mfcd', train_optimizer='sgd',
         logdtm=False, random_state=None):
 
@@ -89,9 +89,9 @@ def __init__(
8989 self .hyperparameters ["logdtm" ] = logdtm
9090 self .hyperparameters ["val_dtm" ] = None
9191 self .hyperparameters ["train_optimizer" ] = train_optimizer
92- self .hyperparameters ['rms_decay' ] = 0.9
93- self .hyperparameters ['adam_decay1' ] = 0.9
94- self .hyperparameters ['adam_decay2' ] = 0.999
92+ self .hyperparameters ['rms_decay' ] = rms_decay
93+ self .hyperparameters ['adam_decay1' ] = adam_decay1
94+ self .hyperparameters ['adam_decay2' ] = adam_decay2
9595
9696 def info (self ):
9797 """
@@ -139,15 +139,15 @@ def train_model(self, dataset, hyperparams=None, top_words=10):
 
         if self.use_partitions:
             print("Building train DTM...")
-            train_dtm = self.build_dtm(train_corpus, self.id2word)
+            self.train_dtm = self.build_dtm(train_corpus, self.id2word)
             print("Building test DTM...")
-            test_dtm = self.build_dtm(test_corpus, self.id2word)
-            hyperparams["dtm"] = train_dtm
-            hyperparams["val_dtm"] = test_dtm
+            self.test_dtm = self.build_dtm(test_corpus, self.id2word)
+            hyperparams["dtm"] = self.train_dtm
+            hyperparams["val_dtm"] = self.test_dtm
         else:
             print("Building DTM...")
-            train_dtm = self.build_dtm(train_corpus, self.id2word)
-            hyperparams["dtm"] = train_dtm
+            self.train_dtm = self.build_dtm(train_corpus, self.id2word)
+            hyperparams["dtm"] = self.train_dtm
             hyperparams["val_dtm"] = None
 
         if "num_topics" not in hyperparams:
@@ -156,31 +156,44 @@ def train_model(self, dataset, hyperparams=None, top_words=10):
         self.hyperparameters.update(hyperparams)
 
         self.trained_model = self.RSM_model()
+        self.trained_model.id2word = self.id2word
         self.trained_model.train(**self.hyperparameters)
 
+        return self.get_model_output(top_words)
+
+        # result = {}
+
+        # result["topic-word-matrix"] = self.trained_model._get_topic_word_matrix()
+
+        # if top_words > 0:
+        #     result['topics'] = self.trained_model._get_topics(top_words)
+
+        # result["topic-document-matrix"] = self.trained_model._get_topic_doc(self.train_dtm)
+
+        # if self.use_partitions:
+        #     result["test-topic-document-matrix"] = self.trained_model._get_topic_doc(self.test_dtm)
+        # else:
+        #     result["test-topic-document-matrix"] = result["topic-document-matrix"]
+
+        # return result
+
+
+    def get_model_output(self, top_words=10):
         result = {}
 
         result["topic-word-matrix"] = self.trained_model._get_topic_word_matrix()
 
         if top_words > 0:
-            topics_output = []
-            for topic in result["topic-word-matrix"]:
-                top_k = np.argsort(topic)[-top_words:]
-                top_k_words = list(reversed([self.id2word[i] for i in top_k]))
-                topics_output.append(top_k_words)
-            result["topics"] = topics_output
-
-            # result["topics"] = self.trained_model.topic_words(topk=top_words, id2word=self.id2word)
+            result['topics'] = self.trained_model._get_topics(top_words)
 
-        result["topic-document-matrix"] = self.trained_model.visible2hidden(train_dtm).T
+        result["topic-document-matrix"] = self.trained_model._get_topic_doc(self.train_dtm)
 
         if self.use_partitions:
-            result["test-topic-document-matrix"] = self.trained_model.visible2hidden(test_dtm).T
+            result["test-topic-document-matrix"] = self.trained_model._get_topic_doc(self.test_dtm)
         else:
             result["test-topic-document-matrix"] = result["topic-document-matrix"]
 
-        return result
-
+        return result
 
 
 ############### preprocessing functions
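
With get_model_output split out of train_model, the OCTIS-style result dictionary ("topic-word-matrix", "topics", "topic-document-matrix", "test-topic-document-matrix") can be rebuilt from an already-trained model, since the DTMs are now stored on self. A hedged sketch, assuming the RSM wrapper name and an OCTIS-style dataset object (both assumptions):

    model = RSM()                                      # hypothetical wrapper instance
    output = model.train_model(dataset, top_words=10)  # trains, then delegates to get_model_output
    print(output["topics"][0])                         # top 10 words of topic 0

    # Re-extract with a wider word list; no retraining happens here:
    wider = model.get_model_output(top_words=25)
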
@@ -323,6 +336,7 @@ def topic_words(self, topk, id2word=None):
 
 
 
+
     def _get_topic_word_matrix(self):
         """
         Return the topic representation of the words
@@ -349,6 +363,29 @@ def _get_topic_word_matrix0(self):
         return topic_word_matrix
 
 
+    def _get_topic_doc(self, dtm):
+        return self.visible2hidden(dtm).T
+
+
+    def _get_topics(self, topk):
+        w_vh, w_v, w_h = self.W
+        T = self.hidden
+        words = np.array([k for k in self.id2word.token2id.keys()])
+
+        toplist = []
+        for t in range(T):
+            topw = w_vh[:, t]
+            bestwords = words[np.argsort(topw)[::-1]][0:topk]
+            toplist.append(bestwords)
+
+        return toplist
+
+        # topics_output = []
+        # for topic in result["topic-word-matrix"]:
+        #     top_k = np.argsort(topic)[-top_words:]
+        #     top_k_words = list(reversed([self.id2word[i] for i in top_k]))
+        #     topics_output.append(top_k_words)
+
 
 ##################################### leapfrog transition operators
 
@@ -609,13 +646,9 @@ def train(self, dtm, num_topics=5, epochs=3, btsz=100,
               lr=0.01, momentum=0.5, K=1, decay=0, penalty_L1=False, penalty_local=False,
               epochs_per_monitor=1, monitor_time=False, monitor_ppl=False,
               train_optimizer='sgd', cd_type='mfcd', logdtm=False,
-
-              # persistent_cd = False, mean_field_cd = False, increase_cd = False,
               rms_decay=0.9, adam_decay1=0.9, adam_decay2=0.999,
-
-              increase_speed=0,
-
-              softstart=0.001, initw=None, val_dtm=None, random_state=None):
+              increase_speed=0, softstart=0.001,
+              initw=None, val_dtm=None, random_state=None):
 
         self.train_dtm = dtm
         hidden = num_topics
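
The signature cleanup keeps the same defaults for the optimizer decay constants, so direct callers of train are unaffected; the wrapper simply forwards whatever was passed to __init__. A hedged sketch of a direct low-level call (standalone RSM_model construction and the toy DTM are assumptions):

    import numpy as np

    rbm = RSM_model()                             # low-level model named in the diff
    dtm = np.random.randint(0, 5, (100, 2000)).astype(float)  # toy document-term matrix
    rbm.train(dtm, num_topics=10, epochs=3, btsz=100,
              rms_decay=0.95)                     # decay constants now overridable per call
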