@@ -124,14 +124,23 @@ def __init__(self, params, network, data, _optimizer_dict=None):
         self._validation_graph = None
 
@classmethod
def run_exists(cls, run_name, path="./", params_format="json", zip_run=True):
    """
    Check if a hyperparameter optimization checkpoint exists.

    Returns True if it does.

    Parameters
    ----------
    run_name : string
        Name of the checkpoint.

    path : str
        Path to check for the saved run.

    params_format : str
        File format in which the parameters were saved (e.g. "json").

    zip_run : bool
        If True, MALA will check for a .zip file. If False,
        then separate files will be checked for.

    Returns
    -------
    checkpoint_exists : bool
        True if the checkpoint exists, False otherwise.
    """
    if zip_run is True:
        return os.path.isfile(os.path.join(path, run_name + ".zip"))

    # Without zipping, every artifact of the run has to be present
    # individually for the checkpoint to count as existing.
    # Bug fix: the optimizer file was previously looked up relative to
    # the current working directory instead of `path`, unlike the other
    # four artifacts, so non-CWD checkpoints were never found.
    suffixes = [
        ".network.pth",
        ".iscaler.pkl",
        ".oscaler.pkl",
        ".params." + params_format,
        ".optimizer.pth",
    ]
    return all(
        os.path.isfile(os.path.join(path, run_name + suffix))
        for suffix in suffixes
    )
@@ -1281,7 +1292,10 @@ def __create_training_checkpoint(self):
         Follows https://pytorch.org/tutorials/recipes/recipes/saving_and_
         loading_a_general_checkpoint.html to some degree.
         """
-        optimizer_name = self.parameters.checkpoint_name + ".optimizer.pth"
+        optimizer_name = os.path.join(
+            self.parameters.checkpoint_path,
+            self.parameters.checkpoint_name + ".optimizer.pth",
+        )
 
         # Next, we save all the other objects.
 
@@ -1306,14 +1320,11 @@ def __create_training_checkpoint(self):
         torch.save(
             save_dict, optimizer_name, _use_new_zipfile_serialization=False
         )
-        if self.parameters.run_name != "":
-            self.save_run(
-                self.parameters.checkpoint_name,
-                save_runner=True,
-                path=self.parameters.run_name,
-            )
-        else:
-            self.save_run(self.parameters.checkpoint_name, save_runner=True)
+        self.save_run(
+            self.parameters.checkpoint_name,
+            save_runner=True,
+            path=self.parameters.checkpoint_path,
+        )
 
     @staticmethod
     def __average_validation(val, device="cpu"):
0 commit comments