Skip to content

Commit fdff98d

Browse files
ADD: sample weights for val_dataset in .fit & validation_freq arg (#295)
* ADD: possibility to add val weights in .fit()
* ADD: corresponding tests
* ADD: validation_freq parameter in model.fit
1 parent 0f0ce10 commit fdff98d

3 files changed

Lines changed: 93 additions & 41 deletions

File tree

choice_learn/models/base_model.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import tqdm
1313

1414
import choice_learn.tf_ops as tf_ops
15+
from choice_learn.data import ChoiceDataset
1516

1617

1718
class ChoiceModel:
@@ -254,6 +255,7 @@ def fit(
254255
choice_dataset,
255256
sample_weight=None,
256257
val_dataset=None,
258+
validation_freq=1,
257259
verbose=0,
258260
):
259261
"""Train the model with a ChoiceDataset.
@@ -264,14 +266,18 @@ def fit(
264266
Input data in the form of a ChoiceDataset
265267
sample_weight : np.ndarray, optional
266268
Sample weight to apply, by default None
267-
val_dataset : ChoiceDataset, optional
269+
val_dataset : ChoiceDataset or (ChoiceDataset, samples_weight), optional
268270
Test ChoiceDataset to evaluate performances on test at each epoch, by default None
269271
verbose : int, optional
270272
print level, for debugging, by default 0
271273
epochs : int, optional
272274
Number of epochs, default is None, meaning we use self.epochs
273275
batch_size : int, optional
274276
Batch size, default is None, meaning we use self.batch_size
277+
validation_freq: int, optional
278+
Only relevant if validation data is provided. Specifies how many training epochs
279+
to run before a new validation run is performed, e.g. validation_freq=2 runs validation
280+
every 2 epochs.
275281
276282
Returns
277283
-------
@@ -411,24 +417,55 @@ def fit(
411417
)
412418

413419
# Test on val_dataset if provided
414-
if val_dataset is not None:
420+
if val_dataset is not None and ((epoch_nb + 1) % validation_freq) == 0:
415421
test_losses = []
416-
for batch_nb, (
417-
shared_features_batch,
418-
items_features_batch,
419-
available_items_batch,
420-
choices_batch,
421-
) in enumerate(val_dataset.iter_batch(shuffle=False, batch_size=batch_size)):
422+
423+
val_samples_weight = None
424+
if isinstance(val_dataset, tuple):
425+
if not len(val_dataset) == 2:
426+
raise ValueError(
427+
"""if argument val_dataset is a tuple, it should be
428+
in the form (ChoiceDataset, weights)"""
429+
)
430+
validation_dataset, val_samples_weight = val_dataset
431+
elif isinstance(val_dataset, ChoiceDataset):
432+
validation_dataset = val_dataset
433+
else:
434+
raise ValueError(
435+
"""val_dataset should be a ChoiceDataset or
436+
a tuple of (ChoiceDataset, weights)."""
437+
)
438+
439+
val_iterator = validation_dataset.iter_batch(
440+
shuffle=False, sample_weight=val_samples_weight, batch_size=batch_size
441+
)
442+
443+
for batch_nb, batch_data in enumerate(val_iterator):
444+
weight_batch = None
445+
if val_samples_weight is not None:
446+
batch_features, weight_batch = batch_data
447+
else:
448+
batch_features = batch_data
449+
450+
(
451+
shared_features_batch,
452+
items_features_batch,
453+
available_items_batch,
454+
choices_batch,
455+
) = batch_features
456+
422457
self.callbacks.on_batch_begin(batch_nb)
423458
self.callbacks.on_test_batch_begin(batch_nb)
424-
test_losses.append(
425-
self.batch_predict(
426-
shared_features_batch,
427-
items_features_batch,
428-
available_items_batch,
429-
choices_batch,
430-
)[0]["optimized_loss"]
431-
)
459+
460+
loss = self.batch_predict(
461+
shared_features_batch,
462+
items_features_batch,
463+
available_items_batch,
464+
choices_batch,
465+
sample_weight=weight_batch,
466+
)[0]["optimized_loss"]
467+
test_losses.append(loss)
468+
432469
val_logs["val_loss"].append(test_losses[-1])
433470
temps_logs = {k: tf.reduce_mean(v) for k, v in val_logs.items()}
434471
self.callbacks.on_test_batch_end(batch_nb, logs=temps_logs)

notebooks/models/simple_mnl.ipynb

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -307,24 +307,24 @@
307307
" <td>Weights_items_features_0</td>\n",
308308
" <td>-0.001533</td>\n",
309309
" <td>0.000621</td>\n",
310-
" <td>-2.469423</td>\n",
311-
" <td>1.353312e-02</td>\n",
310+
" <td>-2.469422</td>\n",
311+
" <td>1.353315e-02</td>\n",
312312
" </tr>\n",
313313
" <tr>\n",
314314
" <th>1</th>\n",
315315
" <td>Weights_items_features_1</td>\n",
316316
" <td>-0.006996</td>\n",
317317
" <td>0.001554</td>\n",
318-
" <td>-4.501964</td>\n",
319-
" <td>6.675720e-06</td>\n",
318+
" <td>-4.501969</td>\n",
319+
" <td>6.732662e-06</td>\n",
320320
" </tr>\n",
321321
" <tr>\n",
322322
" <th>2</th>\n",
323323
" <td>Intercept_0</td>\n",
324324
" <td>1.710969</td>\n",
325-
" <td>0.226741</td>\n",
326-
" <td>7.545904</td>\n",
327-
" <td>0.000000e+00</td>\n",
325+
" <td>0.226742</td>\n",
326+
" <td>7.545903</td>\n",
327+
" <td>4.485301e-14</td>\n",
328328
" </tr>\n",
329329
" <tr>\n",
330330
" <th>3</th>\n",
@@ -338,38 +338,38 @@
338338
" <th>4</th>\n",
339339
" <td>Intercept_2</td>\n",
340340
" <td>1.658846</td>\n",
341-
" <td>0.448417</td>\n",
342-
" <td>3.699342</td>\n",
343-
" <td>2.161264e-04</td>\n",
341+
" <td>0.448416</td>\n",
342+
" <td>3.699345</td>\n",
343+
" <td>2.161564e-04</td>\n",
344344
" </tr>\n",
345345
" <tr>\n",
346346
" <th>5</th>\n",
347347
" <td>Intercept_3</td>\n",
348348
" <td>1.853437</td>\n",
349-
" <td>0.361953</td>\n",
350-
" <td>5.120663</td>\n",
351-
" <td>3.576279e-07</td>\n",
349+
" <td>0.361952</td>\n",
350+
" <td>5.120667</td>\n",
351+
" <td>3.044562e-07</td>\n",
352352
" </tr>\n",
353353
" </tbody>\n",
354354
"</table>\n",
355355
"</div>"
356356
],
357357
"text/plain": [
358358
" Coefficient Name Coefficient Estimation Std. Err z_value \\\n",
359-
"0 Weights_items_features_0 -0.001533 0.000621 -2.469423 \n",
360-
"1 Weights_items_features_1 -0.006996 0.001554 -4.501964 \n",
361-
"2 Intercept_0 1.710969 0.226741 7.545904 \n",
359+
"0 Weights_items_features_0 -0.001533 0.000621 -2.469422 \n",
360+
"1 Weights_items_features_1 -0.006996 0.001554 -4.501969 \n",
361+
"2 Intercept_0 1.710969 0.226742 7.545903 \n",
362362
"3 Intercept_1 0.308263 0.206591 1.492140 \n",
363-
"4 Intercept_2 1.658846 0.448417 3.699342 \n",
364-
"5 Intercept_3 1.853437 0.361953 5.120663 \n",
363+
"4 Intercept_2 1.658846 0.448416 3.699345 \n",
364+
"5 Intercept_3 1.853437 0.361952 5.120667 \n",
365365
"\n",
366366
" P(.>z) \n",
367-
"0 1.353312e-02 \n",
368-
"1 6.675720e-06 \n",
369-
"2 0.000000e+00 \n",
367+
"0 1.353315e-02 \n",
368+
"1 6.732662e-06 \n",
369+
"2 4.485301e-14 \n",
370370
"3 1.356624e-01 \n",
371-
"4 2.161264e-04 \n",
372-
"5 3.576279e-07 "
371+
"4 2.161564e-04 \n",
372+
"5 3.044562e-07 "
373373
]
374374
},
375375
"execution_count": null,
@@ -391,7 +391,7 @@
391391
],
392392
"metadata": {
393393
"kernelspec": {
394-
"display_name": "tf_env",
394+
"display_name": "choice_learn",
395395
"language": "python",
396396
"name": "python3"
397397
},
@@ -405,7 +405,7 @@
405405
"name": "python",
406406
"nbconvert_exporter": "python",
407407
"pygments_lexer": "ipython3",
408-
"version": "3.11.4"
408+
"version": "3.12.11"
409409
}
410410
},
411411
"nbformat": 4,

tests/unit_tests/models/test_simplemnl.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,18 @@ def test_save_load():
153153

154154
assert nll_a == nll_b
155155
shutil.rmtree("test_save")
156+
157+
158+
def test_weighted_val_dataset():
159+
"""Tests that fit runs with a (ChoiceDataset, sample weights) tuple as val_dataset."""
160+
tf.config.run_functions_eagerly(True)
161+
model = SimpleMNL(intercept="item", optimizer="Adam", epochs=100, lr=0.1)
162+
model.instantiate(n_items=3, n_items_features=2, n_shared_features=3)
163+
nll_b = model.evaluate(test_dataset)
164+
model.fit(
165+
test_dataset, get_report=True, val_dataset=(test_dataset, np.ones((len(test_dataset),)))
166+
)
167+
nll_a = model.evaluate(test_dataset, batch_size=-1)
168+
assert nll_a < nll_b
169+
170+
assert model.report.to_numpy().shape == (7, 5)

Comments (0)