Skip to content

Commit 0222ff8

Browse files
Added missing descriptor interface
1 parent 38b7d4c commit 0222ff8

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

mala/network/predictor.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def predict_from_qeout(self, path_to_file, gather_ldos=False):
7878
)
7979

8080
def predict_for_atoms(
81-
self, atoms, gather_ldos=False, temperature=None, save_grads=False
81+
self, atoms, gather_ldos=False, temperature=None, save_grads=False, pass_descriptors=None
8282
):
8383
"""
8484
Get predicted LDOS for an atomic configuration.
@@ -133,24 +133,29 @@ def predict_for_atoms(
133133
self.data.target_calculator.invalidate_target()
134134

135135
# Calculate descriptors.
136-
time_before = perf_counter()
137-
snap_descriptors, local_size = (
138-
self.data.descriptor_calculator.calculate_from_atoms(
139-
atoms, self.data.grid_dimension
140-
)
141-
)
142-
printout(
143-
"Time for descriptor calculation: {:.8f}s".format(
144-
perf_counter() - time_before
145-
),
146-
min_verbosity=2,
147-
)
148136

137+
138+
if pass_descriptors is None:
139+
time_before = perf_counter()
140+
snap_descriptors, local_size = (
141+
self.data.descriptor_calculator.calculate_from_atoms(
142+
atoms, self.data.grid_dimension
143+
)
144+
)
145+
printout(
146+
"Time for descriptor calculation: {:.8f}s".format(
147+
perf_counter() - time_before
148+
),
149+
min_verbosity=2,
150+
)
151+
feature_length = self.data.descriptor_calculator.feature_size
152+
descs_with_xyz = self.data.descriptor_calculator.descriptors_contain_xyz
153+
else:
154+
snap_descriptors, local_size, feature_length, descs_with_xyz = pass_descriptors
149155
# Provide info from current snapshot to target calculator.
150156
self.data.target_calculator.read_additional_calculation_data(
151157
[atoms, self.data.grid_dimension], "atoms+grid"
152158
)
153-
feature_length = self.data.descriptor_calculator.feature_size
154159

155160
# The actual calculation of the LDOS from the descriptors depends
156161
# on whether we run in parallel or serial. In the former case,
@@ -171,7 +176,7 @@ def predict_for_atoms(
171176
return None
172177

173178
else:
174-
if self.data.descriptor_calculator.descriptors_contain_xyz:
179+
if descs_with_xyz:
175180
self.data.target_calculator.local_grid = snap_descriptors[
176181
:, 0:3
177182
].copy()
@@ -196,7 +201,8 @@ def predict_for_atoms(
196201
)
197202

198203
if get_rank() == 0:
199-
if self.data.descriptor_calculator.descriptors_contain_xyz:
204+
#if self.data.descriptor_calculator.descriptors_contain_xyz:
205+
if descs_with_xyz:
200206
snap_descriptors = snap_descriptors[:, :, :, 3:]
201207
feature_length -= 3
202208

@@ -208,9 +214,8 @@ def predict_for_atoms(
208214
if save_grads is True:
209215
self.input_data = snap_descriptors
210216
self.input_data.requires_grad = True
211-
return self._forward_snap_descriptors(
212-
snap_descriptors, save_torch_outputs=True
213-
)
217+
return self._forward_snap_descriptors(snap_descriptors,
218+
save_torch_outputs=True)
214219
else:
215220
return self._forward_snap_descriptors(snap_descriptors)
216221

0 commit comments

Comments
 (0)