@@ -67,6 +67,7 @@ def __init__(self, params, data_handler):
         ).count(True)

         self._trial_type = self.params.hyperparameters.hyper_opt_method
+        self._last_trainer = None

     def __call__(self, trial):
         """
@@ -93,11 +94,13 @@ def __call__(self, trial):
             0, self.params.hyperparameters.number_training_per_trial
         ):
             test_network = Network(self.params)
-            test_trainer = Trainer(
+            self._last_trainer = Trainer(
                 self.params, test_network, self._data_handler
             )
-            test_trainer.train_network()
-            final_validation_loss.append(test_trainer.final_validation_loss)
+            self._last_trainer.train_network()
+            final_validation_loss.append(
+                self._last_trainer.final_validation_loss
+            )
             if (
                 self._trial_type == "optuna"
                 and self.params.hyperparameters.pruner == "multi_training"
@@ -107,7 +110,7 @@ def __call__(self, trial):
                     # meant for values DURING training, but we instead
                     # use it for one of the losses during multiple trainings.
                     # It should not pose a problem though.
-                    trial.report(test_trainer.final_validation_loss, i)
+                    trial.report(self._last_trainer.final_validation_loss, i)
                     if trial.should_prune():
                         raise TrialPruned()

0 commit comments