Skip to content

Commit f8588f1

Browse files
Implemented checkpointing in a different folder
1 parent 4081813 commit f8588f1

File tree

3 files changed

+33
-17
lines changed

3 files changed

+33
-17
lines changed

examples/advanced/ex01_checkpoint_training.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def initial_setup():
3232
# as "ex07".
3333
parameters.running.checkpoints_each_epoch = 5
3434
parameters.running.checkpoint_name = "ex01_checkpoint"
35+
parameters.running.checkpoint_path = "./"
3536

3637
data_handler = mala.DataHandler(parameters)
3738
data_handler.add_snapshot(
@@ -62,9 +63,9 @@ def initial_setup():
6263
return parameters, test_network, data_handler, test_trainer
6364

6465

65-
if mala.Trainer.run_exists("ex01_checkpoint"):
66+
if mala.Trainer.run_exists("ex01_checkpoint", path="./"):
6667
parameters, network, datahandler, trainer = mala.Trainer.load_run(
67-
"ex01_checkpoint"
68+
"ex01_checkpoint", path="./"
6869
)
6970
printout("Starting resumed training.")
7071
else:

mala/common/parameters.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,9 @@ class ParametersRunning(ParametersBase):
881881
Name used for the checkpoints. Using this, multiple runs
882882
can be performed in the same directory.
883883
884+
checkpoint_path : string
885+
Path where the checkpoints will be saved to (and loaded from).
886+
884887
run_name : string
885888
Name of the run used for logging.
886889
@@ -972,6 +975,7 @@ def __init__(self):
972975
self.checkpoints_each_epoch = 0
973976
# self.checkpoint_best_so_far = False
974977
self.checkpoint_name = "checkpoint_mala"
978+
self.checkpoint_path = "./"
975979
self.run_name = ""
976980
self.logging_dir = "./mala_logging"
977981
self.logging_dir_append_date = True

mala/network/trainer.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,23 @@ def __init__(self, params, network, data, _optimizer_dict=None):
124124
self._validation_graph = None
125125

126126
@classmethod
127-
def run_exists(cls, run_name, params_format="json", zip_run=True):
127+
def run_exists(
128+
cls, run_name, path="./", params_format="json", zip_run=True
129+
):
128130
"""
129131
Check if a hyperparameter optimization checkpoint exists.
130132
131133
Returns True if it does.
132134
133135
Parameters
134136
----------
137+
path : str
138+
Path to check for a saved run.
139+
140+
zip_run : bool
141+
If True, MALA will check for a .zip file. If False,
142+
then separate files will be checked for.
143+
135144
run_name : string
136145
Name of the checkpoint.
137146
@@ -145,12 +154,14 @@ def run_exists(cls, run_name, params_format="json", zip_run=True):
145154
146155
"""
147156
if zip_run is True:
148-
return os.path.isfile(run_name + ".zip")
157+
return os.path.isfile(os.path.join(path, run_name + ".zip"))
149158
else:
150-
network_name = run_name + ".network.pth"
151-
iscaler_name = run_name + ".iscaler.pkl"
152-
oscaler_name = run_name + ".oscaler.pkl"
153-
param_name = run_name + ".params." + params_format
159+
network_name = os.path.join(path, run_name + ".network.pth")
160+
iscaler_name = os.path.join(path, run_name + ".iscaler.pkl")
161+
oscaler_name = os.path.join(path, run_name + ".oscaler.pkl")
162+
param_name = os.path.join(
163+
path, run_name + ".params." + params_format
164+
)
154165
optimizer_name = run_name + ".optimizer.pth"
155166
return all(
156167
map(
@@ -1281,7 +1292,10 @@ def __create_training_checkpoint(self):
12811292
Follows https://pytorch.org/tutorials/recipes/recipes/saving_and_
12821293
loading_a_general_checkpoint.html to some degree.
12831294
"""
1284-
optimizer_name = self.parameters.checkpoint_name + ".optimizer.pth"
1295+
optimizer_name = os.path.join(
1296+
self.parameters.checkpoint_path,
1297+
self.parameters.checkpoint_name + ".optimizer.pth",
1298+
)
12851299

12861300
# Next, we save all the other objects.
12871301

@@ -1306,14 +1320,11 @@ def __create_training_checkpoint(self):
13061320
torch.save(
13071321
save_dict, optimizer_name, _use_new_zipfile_serialization=False
13081322
)
1309-
if self.parameters.run_name != "":
1310-
self.save_run(
1311-
self.parameters.checkpoint_name,
1312-
save_runner=True,
1313-
path=self.parameters.run_name,
1314-
)
1315-
else:
1316-
self.save_run(self.parameters.checkpoint_name, save_runner=True)
1323+
self.save_run(
1324+
self.parameters.checkpoint_name,
1325+
save_runner=True,
1326+
path=self.parameters.checkpoint_path,
1327+
)
13171328

13181329
@staticmethod
13191330
def __average_validation(val, device="cpu"):

0 commit comments

Comments
 (0)