Skip to content

Commit ccf51de

Browse files
Quick and dirty implementation of directly saving the model during hyperparameter optimization
1 parent cd8c858 commit ccf51de

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

mala/network/hyper_opt_optuna.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,11 @@ def __create_checkpointing(self, study, trial):
396396
"checkpoint for it.",
397397
min_verbosity=0,
398398
)
399+
# TODO: This is just a quick and dirty way to save the model.
400+
if self.params.hyperparameters.number_training_per_trial == 1:
401+
self.objective._last_trainer.save_run(
402+
self.params.hyperparameters.checkpoint_name + "_best"
403+
)
399404

400405
if need_to_checkpoint is True:
401406
# We need to create a checkpoint!

mala/network/objective_base.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(self, params, data_handler):
6767
).count(True)
6868

6969
self._trial_type = self.params.hyperparameters.hyper_opt_method
70+
self._last_trainer = None
7071

7172
def __call__(self, trial):
7273
"""
@@ -93,11 +94,13 @@ def __call__(self, trial):
9394
0, self.params.hyperparameters.number_training_per_trial
9495
):
9596
test_network = Network(self.params)
96-
test_trainer = Trainer(
97+
self._last_trainer = Trainer(
9798
self.params, test_network, self._data_handler
9899
)
99-
test_trainer.train_network()
100-
final_validation_loss.append(test_trainer.final_validation_loss)
100+
self._last_trainer.train_network()
101+
final_validation_loss.append(
102+
self._last_trainer.final_validation_loss
103+
)
101104
if (
102105
self._trial_type == "optuna"
103106
and self.params.hyperparameters.pruner == "multi_training"
@@ -107,7 +110,7 @@ def __call__(self, trial):
107110
# meant for values DURING training, but we instead
108111
# use it for one of the losses during multiple trainings.
109112
# It should not pose a problem though.
110-
trial.report(test_trainer.final_validation_loss, i)
113+
trial.report(self._last_trainer.final_validation_loss, i)
111114
if trial.should_prune():
112115
raise TrialPruned()
113116

0 commit comments

Comments
 (0)