Skip to content

Commit 88ff606

Browse files
Merge pull request mala-project#637 from mala-project/redesign_lammps_compute_file
Make caching the descriptor calculation input files optional
2 parents 6fc5121 + 27bb734 commit 88ff606

File tree

4 files changed

+50
-39
lines changed

4 files changed

+50
-39
lines changed

mala/common/parameters.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -446,10 +446,11 @@ class ParametersDescriptors(ParametersBase):
446446
mainly exists for debugging purposes. If the atomic density is instead
447447
used for model training itself, this parameter needs to be set.
448448
449-
lammps_compute_file : str
450-
Path to a LAMMPS compute file for the bispectrum descriptor
451-
calculation. MALA has its own collection of compute files which are
452-
used by default. Setting this parameter is thus not necessarys for
449+
custom_lammps_compute_file : str
450+
Path to a LAMMPS compute file for the descriptor calculation.
451+
MALA has its own collection of compute files which are
452+
used by default, i.e., when this string is empty.
453+
Setting this parameter is thus not necessarys for
453454
model training and inference, and it exists mainly for debugging
454455
purposes.
455456
@@ -476,7 +477,7 @@ def __init__(self):
476477

477478
# These affect all descriptors, at least as long all descriptors
478479
# use LAMMPS (which they currently do).
479-
self.lammps_compute_file = ""
480+
self.custom_lammps_compute_file = ""
480481
self.descriptors_contain_xyz = True
481482

482483
# TODO: I would rather handle the parallelization info automatically
@@ -590,7 +591,7 @@ def _update_mpi(self, new_mpi):
590591

591592
# There may have been a serial or parallel run before that is now
592593
# no longer valid.
593-
self.lammps_compute_file = ""
594+
self.custom_lammps_compute_file = ""
594595

595596

596597
class ParametersTargets(ParametersBase):

mala/descriptors/atomic_density.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -211,23 +211,26 @@ def __calculate_lammps(self, outdir, **kwargs):
211211

212212
# For now the file is chosen automatically, because this is used
213213
# mostly under the hood anyway.
214-
filepath = __file__.split("atomic_density")[0]
215-
if self.parameters._configuration["mpi"]:
216-
if self.parameters.use_z_splitting:
217-
self.parameters.lammps_compute_file = os.path.join(
218-
filepath, "in.ggrid.python"
219-
)
214+
if self.parameters.custom_lammps_compute_file != "":
215+
lammps_compute_file = self.parameters.custom_lammps_compute_file
216+
else:
217+
filepath = os.path.dirname(__file__)
218+
if self.parameters._configuration["mpi"]:
219+
if self.parameters.use_z_splitting:
220+
lammps_compute_file = os.path.join(
221+
filepath, "in.ggrid.python"
222+
)
223+
else:
224+
lammps_compute_file = os.path.join(
225+
filepath, "in.ggrid_defaultproc.python"
226+
)
220227
else:
221-
self.parameters.lammps_compute_file = os.path.join(
228+
lammps_compute_file = os.path.join(
222229
filepath, "in.ggrid_defaultproc.python"
223230
)
224-
else:
225-
self.parameters.lammps_compute_file = os.path.join(
226-
filepath, "in.ggrid_defaultproc.python"
227-
)
228231

229232
# Do the LAMMPS calculation and clean up.
230-
lmp.file(self.parameters.lammps_compute_file)
233+
lmp.file(lammps_compute_file)
231234

232235
# Extract the data.
233236
nrows_ggrid = extract_compute_np(

mala/descriptors/bispectrum.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -229,24 +229,26 @@ def __calculate_lammps(self, outdir, **kwargs):
229229

230230
# An empty string means that the user wants to use the standard input.
231231
# What that is differs depending on serial/parallel execution.
232-
if self.parameters.lammps_compute_file == "":
233-
filepath = __file__.split("bispectrum")[0]
232+
# We also have to ensure that no input files from a different
233+
# descriptor calculator gets used.
234+
if self.parameters.custom_lammps_compute_file != "":
235+
lammps_compute_file = self.parameters.custom_lammps_compute_file
236+
else:
237+
filepath = os.path.dirname(__file__)
234238
if self.parameters._configuration["mpi"]:
235239
if self.parameters.use_z_splitting:
236-
self.parameters.lammps_compute_file = os.path.join(
240+
lammps_compute_file = os.path.join(
237241
filepath, "in.bgridlocal.python"
238242
)
239243
else:
240-
self.parameters.lammps_compute_file = os.path.join(
244+
lammps_compute_file = os.path.join(
241245
filepath, "in.bgridlocal_defaultproc.python"
242246
)
243247
else:
244-
self.parameters.lammps_compute_file = os.path.join(
245-
filepath, "in.bgrid.python"
246-
)
248+
lammps_compute_file = os.path.join(filepath, "in.bgrid.python")
247249

248250
# Do the LAMMPS calculation and clean up.
249-
lmp.file(self.parameters.lammps_compute_file)
251+
lmp.file(lammps_compute_file)
250252
self.feature_size = self.__get_feature_size()
251253

252254
# Extract data from LAMMPS calculation.

mala/descriptors/minterpy_descriptors.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -186,23 +186,28 @@ def _calculate(self, atoms, outdir, grid_dimensions, **kwargs):
186186

187187
# For now the file is chosen automatically, because this is used
188188
# mostly under the hood anyway.
189-
filepath = __file__.split("minterpy")[0]
190-
if self.parameters._configuration["mpi"]:
191-
raise Exception(
192-
"Minterpy descriptors cannot be calculated "
193-
"in parallel yet."
189+
if self.parameters.custom_lammps_compute_file != "":
190+
lammps_compute_file = (
191+
self.parameters.custom_lammps_compute_file
194192
)
195-
# if self.parameters.use_z_splitting:
196-
# runfile = os.path.join(filepath, "in.ggrid.python")
197-
# else:
198-
# runfile = os.path.join(filepath, "in.ggrid_defaultproc.python")
199193
else:
200-
self.parameters.lammps_compute_file = os.path.join(
201-
filepath, "in.ggrid_defaultproc.python"
202-
)
194+
filepath = os.path.dirname(__file__)
195+
if self.parameters._configuration["mpi"]:
196+
raise Exception(
197+
"Minterpy descriptors cannot be calculated "
198+
"in parallel yet."
199+
)
200+
# if self.parameters.use_z_splitting:
201+
# runfile = os.path.join(filepath, "in.ggrid.python")
202+
# else:
203+
# runfile = os.path.join(filepath, "in.ggrid_defaultproc.python")
204+
else:
205+
lammps_compute_file = os.path.join(
206+
filepath, "in.ggrid_defaultproc.python"
207+
)
203208

204209
# Do the LAMMPS calculation and clean up.
205-
lmp.file(self.parameters.lammps_compute_file)
210+
lmp.file(lammps_compute_file)
206211

207212
# Extract the data.
208213
nrows_ggrid = extract_compute_np(

0 commit comments

Comments
 (0)