
Commit bae76ec

adding training flexibility to rsm

1 parent 38c7fc5

1 file changed: +61 −28 lines

octis/models/RSM.py: 61 additions & 28 deletions
@@ -22,7 +22,7 @@ def __init__(
         decay=0, penalty_L1=False, penalty_local=False,
         epochs_per_monitor=1,
         monitor_ppl=False, monitor_time=False,
-        increase_speed=0,
+        increase_speed=0, rms_decay=0.9, adam_decay1=0.9, adam_decay2=0.999,
         cd_type='mfcd', train_optimizer='sgd',
         logdtm=False, random_state=None):
 
@@ -89,9 +89,9 @@ def __init__(
         self.hyperparameters["logdtm"] = logdtm
         self.hyperparameters["val_dtm"] = None
         self.hyperparameters["train_optimizer"] = train_optimizer
-        self.hyperparameters['rms_decay'] = 0.9
-        self.hyperparameters['adam_decay1'] = 0.9
-        self.hyperparameters['adam_decay2'] = 0.999
+        self.hyperparameters['rms_decay'] = rms_decay
+        self.hyperparameters['adam_decay1'] = adam_decay1
+        self.hyperparameters['adam_decay2'] = adam_decay2
 
     def info(self):
         """
@@ -139,15 +139,15 @@ def train_model(self, dataset, hyperparams=None, top_words=10):
 
         if self.use_partitions:
             print("Building train DTM...")
-            train_dtm = self.build_dtm(train_corpus, self.id2word)
+            self.train_dtm = self.build_dtm(train_corpus, self.id2word)
             print("Building test DTM...")
-            test_dtm = self.build_dtm(test_corpus, self.id2word)
-            hyperparams["dtm"] = train_dtm
-            hyperparams["val_dtm"] = test_dtm
+            self.test_dtm = self.build_dtm(test_corpus, self.id2word)
+            hyperparams["dtm"] = self.train_dtm
+            hyperparams["val_dtm"] = self.test_dtm
         else:
             print("Building DTM...")
-            train_dtm = self.build_dtm(train_corpus, self.id2word)
-            hyperparams["dtm"] = train_dtm
+            self.train_dtm = self.build_dtm(train_corpus, self.id2word)
+            hyperparams["dtm"] = self.train_dtm
             hyperparams["val_dtm"] = None
 
         if "num_topics" not in hyperparams:
@@ -156,31 +156,44 @@ def train_model(self, dataset, hyperparams=None, top_words=10):
         self.hyperparameters.update(hyperparams)
 
         self.trained_model = self.RSM_model()
+        self.trained_model.id2word = self.id2word
         self.trained_model.train(**self.hyperparameters)
 
+        return self.get_model_output(top_words)
+
+        # result = {}
+
+        # result["topic-word-matrix"] = self.trained_model._get_topic_word_matrix()
+
+        # if top_words > 0:
+        #     result['topics'] = self.trained_model._get_topics(top_words)
+
+        # result["topic-document-matrix"] = self.trained_model._get_topic_doc(self.train_dtm)
+
+        # if self.use_partitions:
+        #     result["test-topic-document-matrix"] = self.trained_model._get_topic_doc(self.test_dtm)
+        # else:
+        #     result["test-topic-document-matrix"] = result["topic-document-matrix"]
+
+        # return result
+
+
+    def get_model_output(self, top_words=10):
         result = {}
 
         result["topic-word-matrix"] = self.trained_model._get_topic_word_matrix()
 
         if top_words > 0:
-            topics_output = []
-            for topic in result["topic-word-matrix"]:
-                top_k = np.argsort(topic)[-top_words:]
-                top_k_words = list(reversed([self.id2word[i] for i in top_k]))
-                topics_output.append(top_k_words)
-            result["topics"] = topics_output
-
-            #result["topics"] = self.trained_model.topic_words(topk=top_words, id2word=self.id2word)
+            result['topics'] = self.trained_model._get_topics(top_words)
 
-        result["topic-document-matrix"] = self.trained_model.visible2hidden(train_dtm).T
+        result["topic-document-matrix"] = self.trained_model._get_topic_doc(self.train_dtm)
 
         if self.use_partitions:
-            result["test-topic-document-matrix"] = self.trained_model.visible2hidden(test_dtm).T
+            result["test-topic-document-matrix"] = self.trained_model._get_topic_doc(self.test_dtm)
         else:
             result["test-topic-document-matrix"] = result["topic-document-matrix"]
 
-        return result
-
+        return result
 
 
     ############### preprocessing functions
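
Beyond delegating top-word extraction to the new _get_topics helper, this hunk caches the train/test DTMs on self and factors output assembly into get_model_output, so outputs can be rebuilt from a fitted model without retraining. A hedged sketch of that workflow (the Dataset loading is standard OCTIS; the RSM import path and default constructor are assumed):

    from octis.dataset.dataset import Dataset
    from octis.models.RSM import RSM  # assumed import path and class name

    dataset = Dataset()
    dataset.fetch_dataset("20NewsGroup")

    model = RSM()  # assumed defaults
    output = model.train_model(dataset, top_words=10)

    # Rebuild the output dict with a different cutoff, no retraining:
    # train_model() stored the DTMs on self, so get_model_output() can
    # recompute topics and the document-topic matrices directly.
    wider = model.get_model_output(top_words=25)
    print(wider["topics"][0])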
@@ -323,6 +336,7 @@ def topic_words(self, topk, id2word=None):
 
 
 
+
     def _get_topic_word_matrix(self):
         """
         Return the topic representation of the words
@@ -349,6 +363,29 @@ def _get_topic_word_matrix0(self):
         return topic_word_matrix
 
 
+    def _get_topic_doc(self, dtm):
+        return self.visible2hidden(dtm).T
+
+
+    def _get_topics(self, topk):
+        w_vh, w_v, w_h = self.W
+        T = self.hidden
+        words = np.array([k for k in self.id2word.token2id.keys()])
+
+        toplist = []
+        for t in range(T):
+            topw = w_vh[: , t]
+            bestwords = words[np.argsort(topw)[::-1]][0:topk]
+            toplist.append(bestwords)
+
+        return toplist
+
+        # topics_output = []
+        # for topic in result["topic-word-matrix"]:
+        #     top_k = np.argsort(topic)[-top_words:]
+        #     top_k_words = list(reversed([self.id2word[i] for i in top_k]))
+        #     topics_output.append(top_k_words)
+
 
     ##################################### leapfrog trainsition operators
 
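The new _get_topics ranks the vocabulary by each topic's column of the visible-to-hidden weight matrix w_vh: a descending argsort followed by a top-k slice. A self-contained illustration of that selection idiom on toy data:

    import numpy as np

    # Toy stand-ins for the vocabulary array and one column of w_vh.
    words = np.array(["apple", "ball", "cat", "dog", "egg"])
    topw = np.array([0.1, 0.9, 0.4, 0.7, 0.2])

    # Sort the weights in descending order and keep the first topk
    # words, exactly the idiom used in _get_topics above.
    topk = 3
    bestwords = words[np.argsort(topw)[::-1]][0:topk]
    print(bestwords)  # ['ball' 'dog' 'cat']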
@@ -609,13 +646,9 @@ def train(self, dtm, num_topics=5, epochs=3, btsz=100,
              lr=0.01, momentum=0.5, K=1, decay=0, penalty_L1=False, penalty_local=False,
              epochs_per_monitor=1, monitor_time = False, monitor_ppl = False,
              train_optimizer='sgd', cd_type='mfcd', logdtm=False,
-
-             # persistent_cd = False, mean_field_cd = False, increase_cd = False,
              rms_decay=0.9, adam_decay1=0.9, adam_decay2=0.999,
-
-             increase_speed = 0,
-
-             softstart=0.001, initw=None, val_dtm=None, random_state=None):
+             increase_speed = 0, softstart=0.001,
+             initw=None, val_dtm=None, random_state=None):
 
         self.train_dtm = dtm
         hidden = num_topics
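
The final hunk compacts train()'s signature, deleting the commented-out persistent_cd/mean_field_cd/increase_cd flags while keeping rms_decay and the Adam decays as overridable parameters. The optimizer update code itself is not part of this diff; for orientation, a textbook sketch of where such decay parameters usually enter the updates (not the repository's implementation):

    import numpy as np

    def rmsprop_step(w, grad, cache, lr=0.01, rms_decay=0.9, eps=1e-8):
        # cache: running average of squared gradients, decayed by rms_decay.
        cache = rms_decay * cache + (1 - rms_decay) * grad ** 2
        return w - lr * grad / (np.sqrt(cache) + eps), cache

    def adam_step(w, grad, m, v, t, lr=0.01,
                  adam_decay1=0.9, adam_decay2=0.999, eps=1e-8):
        # m, v: running first/second moment estimates; t: 1-based step count.
        m = adam_decay1 * m + (1 - adam_decay1) * grad
        v = adam_decay2 * v + (1 - adam_decay2) * grad ** 2
        m_hat = m / (1 - adam_decay1 ** t)  # bias correction
        v_hat = v / (1 - adam_decay2 ** t)
        return w - lr * m_hat / (np.sqrt(v_hat) + eps), m, v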
