Skip to content

Commit f8bc86c

Browse files
Fixed issue with hyperparameterm optimozation
1 parent 08400af commit f8bc86c

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

examples/advanced/ex06_distributed_hyperparameter_optimization.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,20 @@
3636
parameters.hyperparameters.study_name = "ex06"
3737
parameters.hyperparameters.rdb_storage = "sqlite:///ex06.db"
3838

39-
# Hyperparameter optimization can be further refined by using ensemble training
40-
# at each step and by using a different metric then the validation loss
41-
# (e.g. the band energy). It is recommended not to use the ensemble training
42-
# method in Single-GPU use, as it naturally drastically decreases performance.
4339
parameters.targets.ldos_gridsize = 11
4440
parameters.targets.ldos_gridspacing_ev = 2.5
4541
parameters.targets.ldos_gridoffset_ev = -5
4642
parameters.hyperparameters.number_training_per_trial = 3
47-
parameters.running.after_training_metric = "band_energy"
43+
44+
# Hyperparameter optimization can be further refined by using ensemble training
45+
# at each step and by using a different metric then the validation loss
46+
# (e.g. the band energy). It is recommended not to use the ensemble training
47+
# method in Single-GPU use, as it naturally drastically decreases performance.
48+
# For this small example setting, using the band energy as the after training
49+
# metric is not recommended, since the small data size makes
50+
# an accurate hyperparameter search difficult. For larger systems, enabling
51+
# this option is recommended.
52+
# parameters.running.after_training_metric = "band_energy"
4853

4954
data_handler = mala.DataHandler(parameters)
5055

mala/network/trainer.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ def _validate_network(self, data_set_fractions, metrics):
629629
The data set fractions to validate on. Can be "train" or
630630
"validation". "test" is not supported.
631631
632-
metrics : list
632+
metrics : List
633633
List of metrics to calculate. Can be "ldos", or a number of
634634
error metrics (see Tester class). Most common apart from "ldos"
635635
is arguably "band_energy".
@@ -715,10 +715,7 @@ def _validate_network(self, data_set_fractions, metrics):
715715
# case for, e.g., distributed network trainings), we can
716716
# use a faster (or at least better parallelizing) code
717717

718-
if (
719-
len(self.parameters.validation_metrics) == 1
720-
and self.parameters.validation_metrics[0] == "ldos"
721-
):
718+
if len(metrics) == 1 and metrics[0] == "ldos":
722719

723720
errors[data_set_type]["ldos"] = (
724721
self.__calculate_validation_error_ldos_only(

0 commit comments

Comments
 (0)