Skip to content

Commit e154d2a

Browse files
On-the-fly training works for the RAM case
1 parent 13146e1 commit e154d2a

File tree

9 files changed

+113
-17
lines changed

9 files changed

+113
-17
lines changed

mala/datahandling/data_handler.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""DataHandler class that loads and scales data."""
22

33
import os
4+
import tempfile
45

56
import numpy as np
67
import torch
@@ -169,6 +170,19 @@ def clear_data(self):
169170
self.output_data_scaler.reset()
170171
super(DataHandler, self).clear_data()
171172

173+
def delete_temporary_data(self):
    """
    Delete temporary data files.

    These may have been created during a training or testing process
    when using atomic positions for on-the-fly calculation of descriptors
    rather than precomputed data files.

    Every snapshot in ``self.parameters.snapshot_directories_list`` whose
    ``temporary_input_file`` is set gets that file removed (if it exists
    as a regular file) and the attribute reset to ``None``, so that no
    snapshot keeps pointing at a file that is gone. Calling this method
    repeatedly is therefore safe.
    """
    for snapshot in self.parameters.snapshot_directories_list:
        temporary_file = snapshot.temporary_input_file
        if temporary_file is None:
            continue

        if os.path.isfile(temporary_file):
            try:
                os.remove(temporary_file)
            except FileNotFoundError:
                # The file vanished between the check and the removal
                # (e.g. cleaned up concurrently); nothing left to do.
                pass

        # Do not keep a dangling reference to a deleted file.
        snapshot.temporary_input_file = None
185+
172186
# Preparing data
173187
######################
174188

@@ -595,14 +609,37 @@ def __load_data(self, function, data_type):
595609
snapshot.input_npy_directory, snapshot.input_npy_file
596610
)
597611
units = snapshot.input_units
612+
613+
# If the input for the descriptors is actually a JSON
614+
# file then we need to calculate the descriptors.
615+
if snapshot.snapshot_type == "json+numpy":
616+
snapshot.temporary_input_file = (
617+
tempfile.NamedTemporaryFile(
618+
delete=False,
619+
prefix=snapshot.input_npy_file.split(".")[0],
620+
suffix=".in.npy",
621+
dir=snapshot.input_npy_directory,
622+
).name
623+
)
624+
descriptors, grid = (
625+
self.descriptor_calculator.calculate_from_json(
626+
file
627+
)
628+
)
629+
np.save(snapshot.temporary_input_file, descriptors)
630+
file = snapshot.temporary_input_file
631+
598632
else:
599633
file = os.path.join(
600634
snapshot.output_npy_directory,
601635
snapshot.output_npy_file,
602636
)
603637
units = snapshot.output_units
604638

605-
if snapshot.snapshot_type == "numpy":
639+
if (
640+
snapshot.snapshot_type == "numpy"
641+
or snapshot.snapshot_type == "json+numpy"
642+
):
606643
calculator.read_from_numpy_file(
607644
file,
608645
units=units,

mala/datahandling/data_handler_base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,15 @@ def _check_snapshots(self, comm=None):
207207
),
208208
comm=comm,
209209
)
210+
elif snapshot.snapshot_type == "json+numpy":
211+
tmp_dimension = (
212+
self.descriptor_calculator.read_dimensions_from_json(
213+
os.path.join(
214+
snapshot.input_npy_directory,
215+
snapshot.input_npy_file,
216+
)
217+
)
218+
)
210219
else:
211220
raise Exception("Unknown snapshot file type.")
212221

@@ -235,7 +244,10 @@ def _check_snapshots(self, comm=None):
235244
snapshot.output_npy_directory,
236245
min_verbosity=1,
237246
)
238-
if snapshot.snapshot_type == "numpy":
247+
if (
248+
snapshot.snapshot_type == "numpy"
249+
or snapshot.snapshot_type == "json+numpy"
250+
):
239251
tmp_dimension = (
240252
self.target_calculator.read_dimensions_from_numpy_file(
241253
os.path.join(

mala/datahandling/snapshot.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ def __init__(
133133
self.input_dimension = None
134134
self.output_dimension = None
135135

136+
# Temporary descriptor files, which may be needed.
137+
self.temporary_input_file = None
138+
136139
@classmethod
137140
def from_json(cls, json_dict):
138141
"""

mala/descriptors/atomic_density.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ def _calculate(self, outdir, **kwargs):
119119
else:
120120
return self.__calculate_python(**kwargs)
121121

122+
def _read_feature_dimension_from_json(self, json_dict):
    """
    Return the feature dimension for data described by a JSON dictionary.

    The dictionary is currently not inspected; a constant is returned.

    Parameters
    ----------
    json_dict : dict
        Parsed content of the snapshot JSON file (unused here).

    Returns
    -------
    feature_dimension : int
        Always 4.
    """
    # For now, has to be adapted in the multielement case.
    # NOTE(review): presumably three coordinate columns plus one density
    # value - confirm.
    feature_dimension = 4
    return feature_dimension
125+
122126
def __calculate_lammps(self, outdir, **kwargs):
123127
"""Perform actual Gaussian descriptor calculation."""
124128
# For version compatibility; older lammps versions (the serial version

mala/descriptors/bispectrum.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,24 @@ def _calculate(self, outdir, **kwargs):
120120
else:
121121
return self.__calculate_python(**kwargs)
122122

123+
def _read_feature_dimension_from_json(self, json_dict):
    """
    Return the feature dimension for data described by a JSON dictionary.

    The dictionary itself is not inspected; the value follows from the
    descriptor parameters alone.

    Parameters
    ----------
    json_dict : dict
        Parsed content of the snapshot JSON file (unused here).

    Returns
    -------
    feature_dimension : int
        Full feature size, reduced by 3 if
        ``parameters.descriptors_contain_xyz`` is set.
    """
    # Number of columns to discount from the full feature size.
    discounted_columns = 3 if self.parameters.descriptors_contain_xyz else 0
    return self.__get_feature_size() - discounted_columns
128+
129+
def __get_feature_size(self):
    """
    Compute the total number of descriptor columns.

    Returns
    -------
    feature_size : int
        Three leading columns plus the number of bispectrum coefficients
        that follows analytically from ``parameters.bispectrum_twojmax``.
    """
    twojmax = self.parameters.bispectrum_twojmax

    # Analytical relation for fingerprint length:
    # (2j + 2) * (2j + 3) * (2j + 4) / 24, evaluated as integer division.
    coefficient_count = (twojmax + 2) * (twojmax + 3) * (twojmax + 4) // 24

    leading_columns = 3
    return leading_columns + coefficient_count
140+
123141
def __calculate_lammps(self, outdir, **kwargs):
124142
"""
125143
Perform bispectrum calculation using LAMMPS.
@@ -173,19 +191,7 @@ def __calculate_lammps(self, outdir, **kwargs):
173191

174192
# Do the LAMMPS calculation and clean up.
175193
lmp.file(self.parameters.lammps_compute_file)
176-
177-
# Set things not accessible from LAMMPS
178-
# First 3 cols are x, y, z, coords
179-
ncols0 = 3
180-
181-
# Analytical relation for fingerprint length
182-
ncoeff = (
183-
(self.parameters.bispectrum_twojmax + 2)
184-
* (self.parameters.bispectrum_twojmax + 3)
185-
* (self.parameters.bispectrum_twojmax + 4)
186-
)
187-
ncoeff = ncoeff // 24 # integer division
188-
self.feature_size = ncols0 + ncoeff
194+
self.feature_size = self.__get_feature_size()
189195

190196
# Extract data from LAMMPS calculation.
191197
# This is different for the parallel and the serial case.

mala/descriptors/descriptor.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Base class for all descriptor calculators."""
22

33
from abc import abstractmethod
4-
from functools import cached_property
4+
import json
55
import os
66
import tempfile
77

88
import ase
9+
from ase.cell import Cell
910
from ase.units import m
1011
from ase.neighborlist import NeighborList, NewPrimitiveNeighborList
1112
import numpy as np
@@ -375,6 +376,16 @@ def calculate_from_qe_out(
375376

376377
return self._calculate(working_directory, **kwargs)
377378

379+
def calculate_from_json(self, json_file, working_directory=".", **kwargs):
    """
    Calculate the descriptors based on a JSON file.

    The JSON content has to provide the entries ``"grid_dimensions"``,
    ``"atoms"`` (a dictionary readable by ``ase.Atoms.fromdict``) and
    ``"voxel"`` (with an ``"array"`` sub-entry).

    Parameters
    ----------
    json_file : str, os.PathLike or file-like
        Path to the JSON file, or an already opened file object. An
        opened file object is read but not closed by this method.

    working_directory : str
        Handed on unchanged to the descriptor calculation.

    kwargs : dict
        Handed on unchanged to the descriptor calculation.

    Returns
    -------
    The return value of ``self._calculate``.
    """
    if isinstance(json_file, (str, os.PathLike)):
        # Use a context manager so the handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(json_file, encoding="utf-8") as json_handle:
            json_dict = json.load(json_handle)
    else:
        json_dict = json.load(json_file)

    self.grid_dimensions = json_dict["grid_dimensions"]
    self._atoms = ase.Atoms.fromdict(json_dict["atoms"])
    self._voxel = Cell(json_dict["voxel"]["array"])
    return self._calculate(working_directory, **kwargs)
388+
378389
def calculate_from_atoms(
379390
self, atoms, grid_dimensions, working_directory=".", **kwargs
380391
):
@@ -573,6 +584,16 @@ def convert_local_to_3d(self, descriptors_np):
573584
).transpose([2, 1, 0, 3])
574585
return descriptors_full, local_offset, local_reach
575586

587+
def read_dimensions_from_json(self, json_file):
    """
    Read the dimensions of the descriptor data described by a JSON file.

    No descriptors are calculated; the grid dimensions are taken from the
    ``"grid_dimensions"`` entry of the JSON content and the feature
    dimension is appended as reported by
    ``self._read_feature_dimension_from_json``.

    Parameters
    ----------
    json_file : str, os.PathLike or file-like
        Path to the JSON file, or an already opened file object. An
        opened file object is read but not closed by this method.

    Returns
    -------
    dimensions : list
        Grid dimensions followed by the feature dimension.
    """
    if isinstance(json_file, (str, os.PathLike)):
        # Use a context manager so the handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(json_file, encoding="utf-8") as json_handle:
            json_dict = json.load(json_handle)
    else:
        json_dict = json.load(json_file)

    feature_dimension = self._read_feature_dimension_from_json(json_dict)
    return json_dict["grid_dimensions"] + [feature_dimension]
596+
576597
# Private methods
577598
#################
578599

@@ -1021,5 +1042,9 @@ def _grid_to_coord(self, gridpoint):
10211042
def _calculate(self, outdir, **kwargs):
10221043
pass
10231044

1045+
@abstractmethod
def _read_feature_dimension_from_json(self, json_dict):
    """
    Return the feature dimension for data described by a JSON dictionary.

    Has to be implemented by the concrete descriptor classes; the base
    implementation does nothing and returns None.

    Parameters
    ----------
    json_dict : dict
        Parsed content of the snapshot JSON file.
    """
1048+
10241049
def _set_feature_size_from_array(self, array):
10251050
self.feature_size = np.shape(array)[-1]

mala/descriptors/minterpy_descriptors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ def backconvert_units(array, out_units):
8787
else:
8888
raise Exception("Unsupported unit for Minterpy descriptors.")
8989

90+
def _read_feature_dimension_from_json(self, json_dict):
    """
    Reject reading the feature dimension from a JSON dictionary.

    Parameters
    ----------
    json_dict : dict
        Parsed content of the snapshot JSON file (unused here).

    Raises
    ------
    Exception
        Always, since this is not implemented for Minterpy descriptors.
    """
    message = (
        "This feature has not been implemented for Minterpy "
        "descriptors."
    )
    raise Exception(message)
95+
9096
def _calculate(self, atoms, outdir, grid_dimensions, **kwargs):
9197
# For version compatibility; older lammps versions (the serial version
9298
# we still use on some machines) have these constants as part of the

mala/network/trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,9 @@ def train_network(self):
597597
)
598598
self.final_validation_loss = vloss
599599

600+
# Cleaning up temporary data files.
601+
self.data.delete_temporary_data()
602+
600603
# Clean-up for pre-fetching lazy loading.
601604
if self.data.parameters.use_lazy_loading_prefetch:
602605
self._training_data_loaders.cleanup()

mala/targets/target.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def read_additional_calculation_data(self, data, data_type=None):
653653
}
654654
self.atomic_forces_dft = None
655655
self.entropy_contribution_dft_calculation = None
656-
self.grid_dimensions = [0, 0, 0]
656+
self.grid_dimensions = json_dict["grid_dimensions"]
657657
self.atoms = None
658658

659659
for key in json_dict:

0 commit comments

Comments
 (0)