
Commit 442cd33

adding tests for RSM and oRSM
1 parent 07f988c commit 442cd33

2 files changed: 89 additions & 32 deletions

octis/models/RSM.py

Lines changed: 25 additions & 32 deletions
@@ -17,57 +17,57 @@ class RSM(AbstractModel):
     update_with_test = False
 
     def __init__(
-            self, num_topics=50, epochs=5, btsz=100,
+            self, num_topics=50, epochs=5, btsz=100,
             lr=0.01, momentum=0.1, K=1, softstart=0.001,
             decay=0, penalty_L1=False, penalty_local=False,
-            epochs_per_monitor=1,
+            epochs_per_monitor=1,
             monitor_ppl=False, monitor_time=False,
+            increase_speed=0,
+            cd_type='mfcd', train_optimizer='sgd',
+            logdtm=False, random_state=None):
 
-            #persistent_cd = False, mean_field_cd = True, increase_cd = False,
-            increase_speed = 0,
-            cd_type='mfcd',
-            train_optimizer='sgd',
-            logdtm=False,
-            random_state=None):
-
-        '''
+        '''
         Parameters
         ----------
         num_topics : number of topics
         epochs : number of training epochs
         btsz : batch size
         lr : learning rate
-        momentum : momentum of momentum optimizer (applied only if train_optimizer='momentum')
-        rms_decay : decay rate for RMSProp optimizer (applied only if train_optimizer='rmsprop')
-        adam_decay1 : first decay rate for Adam optimizer (applied only if train_optimizer='adam')
-        adam_decay2 : second decay rate for Adam optimizer (applied only if train_optimizer='adam')
+        momentum : momentum of momentum optimizer
+            (applied only if train_optimizer='momentum')
+        rms_decay : decay rate for RMSProp optimizer
+            (applied only if train_optimizer='rmsprop')
+        adam_decay1 : first decay rate for Adam optimizer
+            (applied only if train_optimizer='adam')
+        adam_decay2 : second decay rate for Adam optimizer
+            (applied only if train_optimizer='adam')
         K : number of Gibbs sampling steps when using KCD
         decay : penalization coefficient, default 0 (no penalization)
         penalty_L1 : if True uses L1 penalization, else L2 penalization
-        penalty_local : if True uses local penalization, else global penalization
-        softstart : initialization scale for weights (randomly drawn from N(0,1)*softstart)
-        logdtm : if True each cell of the dtm is transformed as log(1+cell),
+        penalty_local : if True uses local penalization,
+            else global penalization
+        softstart : initialization scale for weights
+            (randomly drawn from N(0,1)*softstart)
+        logdtm : if True each cell of the dtm is transformed as log(1+cell),
             otherwise the raw counts are used
         monitor : if True prints training information during training
 
-        cd_type : type of contrastive divergence to use, 'kcd', 'pcd', 'mfcd' (default) or 'gradkcd'
+        cd_type : type of contrastive divergence to use,
+            'kcd', 'pcd', 'mfcd' (default) or 'gradkcd' :
             'kcd' stands for k-step contrastive divergence
             'pcd' stands for persistent contrastive divergence
             'mfcd' stands for mean-field contrastive divergence
-            'gradkcd' stands for gradual k-step contrastive divergence,
+            'gradkcd' stands for gradual k-step contrastive divergence,
                 where k increases over epochs by a factor increase_speed
         train_optimizer : training optimizer to use :
-            'full' for full batch training,
+            'full' for full batch training,
             'sgd' for simple stochastic gradient descent,
             'minibatch' for mini-batch training,
             'momentum' for mini-batch with momentum,
             'rmsprop' for RMSProp optimizer,
             'adam' for Adam optimizer,
             'adagrad' for Adagrad optimizer
         '''
-
-
-
         super().__init__()
         self.hyperparameters = dict()
         self.hyperparameters["num_topics"] = num_topics
@@ -77,10 +77,7 @@ def __init__(
         self.hyperparameters["K"] = K
         self.hyperparameters["softstart"] = softstart
         self.hyperparameters["epochs"] = epochs
-        #self.hyperparameters["increase_cd"] = increase_cd
         self.hyperparameters["increase_speed"] = increase_speed
-        #self.hyperparameters["mean_field_cd"] = mean_field_cd
-        #self.hyperparameters["persistent_cd"] = persistent_cd
         self.hyperparameters["monitor_time"] = monitor_time
         self.hyperparameters["monitor_ppl"] = monitor_ppl
         self.hyperparameters["epochs_per_monitor"] = epochs_per_monitor
@@ -96,7 +93,6 @@ def __init__(
         self.hyperparameters['adam_decay1'] = 0.9
         self.hyperparameters['adam_decay2'] = 0.999
 
-
     def info(self):
         """
         Returns model informations
@@ -106,16 +102,12 @@ def info(self):
             "name": "RSM, Replicated Softmax Model",
         }
 
-
     def hyperparameters_info(self):
         """
         Returns hyperparameters informations
         """
         return defaults.RSM_hyperparameters_info
 
-
-
-
     def train_model(self, dataset, hyperparams=None, top_words=10):
         """
         Train the model and return output
@@ -137,7 +129,8 @@ def train_model(self, dataset, hyperparams=None, top_words=10):
             hyperparams = {}
 
         if self.use_partitions:
-            train_corpus, test_corpus = dataset.get_partitioned_corpus(use_validation = False)
+            train_corpus, test_corpus = dataset.get_partitioned_corpus(
+                use_validation = False)
         else:
            train_corpus = dataset.get_corpus()
 
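For context, the reworked constructor above is used much like the new tests below use it. The following is a minimal, illustrative sketch only: the dataset folder path and the hyperparameter values are assumptions, not part of this commit; the output keys shown are the ones asserted by the tests.

# Minimal usage sketch (assumed path and hyperparameter values; mirrors the new tests)
from octis.dataset.dataset import Dataset
from octis.models.RSM import RSM

dataset = Dataset()
dataset.load_custom_dataset_from_folder("preprocessed_datasets/M10")  # assumed local path

model = RSM(num_topics=10, epochs=5, cd_type='mfcd', train_optimizer='sgd')
output = model.train_model(dataset)

# Keys asserted by the new tests: 'topics', 'topic-word-matrix',
# 'topic-document-matrix', 'test-topic-document-matrix'
print(output['topics'][:3])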
tests/test_octis.py

Lines changed: 64 additions & 0 deletions
@@ -5,6 +5,8 @@
 import pytest
 
 from octis.dataset.dataset import Dataset
+from octis.models.RSM import RSM
+from octis.models.oRSM import oRSM
 from octis.models.LDA import LDA
 from octis.models.LDA_tomopy import LDA_tomopy as LDATOMOTO
 from octis.models.ETM import ETM
@@ -574,3 +576,65 @@ def test_model_output_prodlda_not_partitioned(data_dir):
     assert type(output['topic-document-matrix']) == np.ndarray
     assert output['topic-document-matrix'].shape == (
         num_topics, len(dataset.get_corpus()))
+
+
+
+
+def test_model_output_rsm(data_dir):
+    dataset = Dataset()
+    dataset.load_custom_dataset_from_folder(data_dir + '/M10')
+    num_topics = 3
+    model = RSM(num_topics=num_topics, epochs=2)
+    output = model.train_model(dataset)
+    assert 'topics' in output.keys()
+    assert 'topic-word-matrix' in output.keys()
+    assert 'test-topic-document-matrix' in output.keys()
+
+    # check topics format
+    assert type(output['topics']) == list
+    assert len(output['topics']) == num_topics
+
+    # check topic-word-matrix format
+    assert type(output['topic-word-matrix']) == np.ndarray
+    assert output['topic-word-matrix'].shape == (num_topics, len(
+        dataset.get_vocabulary()))
+
+    # check topic-document-matrix format
+    assert type(output['topic-document-matrix']) == np.ndarray
+    assert output['topic-document-matrix'].shape == (num_topics, len(
+        dataset.get_partitioned_corpus()[0]))
+
+    # check test-topic-document-matrix format
+    assert type(output['test-topic-document-matrix']) == np.ndarray
+    assert output['test-topic-document-matrix'].shape == (num_topics, len(
+        dataset.get_partitioned_corpus()[2]))
+
+
+def test_model_output_orsm(data_dir):
+    dataset = Dataset()
+    dataset.load_custom_dataset_from_folder(data_dir + '/M10')
+    num_topics = 3
+    model = oRSM(num_topics=num_topics, epochs=2)
+    output = model.train_model(dataset)
+    assert 'topics' in output.keys()
+    assert 'topic-word-matrix' in output.keys()
+    assert 'test-topic-document-matrix' in output.keys()
+
+    # check topics format
+    assert type(output['topics']) == list
+    assert len(output['topics']) == num_topics
+
+    # check topic-word-matrix format
+    assert type(output['topic-word-matrix']) == np.ndarray
+    assert output['topic-word-matrix'].shape == (num_topics, len(
+        dataset.get_vocabulary()))
+
+    # check topic-document-matrix format
+    assert type(output['topic-document-matrix']) == np.ndarray
+    assert output['topic-document-matrix'].shape == (num_topics, len(
+        dataset.get_partitioned_corpus()[0]))
+
+    # check test-topic-document-matrix format
+    assert type(output['test-topic-document-matrix']) == np.ndarray
+    assert output['test-topic-document-matrix'].shape == (num_topics, len(
+        dataset.get_partitioned_corpus()[2]))
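To run only the two new tests, a standard pytest selection by name works. The snippet below is an illustrative sketch, assuming it is executed from the repository root where the data_dir fixture can resolve the M10 dataset.

# Illustrative: select only the new RSM/oRSM tests by name (run from the repo root)
import pytest

pytest.main(["tests/test_octis.py",
             "-k", "test_model_output_rsm or test_model_output_orsm"])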
