@@ -78,7 +78,7 @@ def predict_from_qeout(self, path_to_file, gather_ldos=False):
7878 )
7979
8080 def predict_for_atoms (
81- self , atoms , gather_ldos = False , temperature = None , save_grads = False
81+ self , atoms , gather_ldos = False , temperature = None , save_grads = False , pass_descriptors = None
8282 ):
8383 """
8484 Get predicted LDOS for an atomic configuration.
@@ -133,24 +133,29 @@ def predict_for_atoms(
133133 self .data .target_calculator .invalidate_target ()
134134
135135 # Calculate descriptors.
136- time_before = perf_counter ()
137- snap_descriptors , local_size = (
138- self .data .descriptor_calculator .calculate_from_atoms (
139- atoms , self .data .grid_dimension
140- )
141- )
142- printout (
143- "Time for descriptor calculation: {:.8f}s" .format (
144- perf_counter () - time_before
145- ),
146- min_verbosity = 2 ,
147- )
148136
137+
138+ if pass_descriptors is None :
139+ time_before = perf_counter ()
140+ snap_descriptors , local_size = (
141+ self .data .descriptor_calculator .calculate_from_atoms (
142+ atoms , self .data .grid_dimension
143+ )
144+ )
145+ printout (
146+ "Time for descriptor calculation: {:.8f}s" .format (
147+ perf_counter () - time_before
148+ ),
149+ min_verbosity = 2 ,
150+ )
151+ feature_length = self .data .descriptor_calculator .feature_size
152+ descs_with_xyz = self .data .descriptor_calculator .descriptors_contain_xyz
153+ else :
154+ snap_descriptors , local_size , feature_length , descs_with_xyz = pass_descriptors
149155 # Provide info from current snapshot to target calculator.
150156 self .data .target_calculator .read_additional_calculation_data (
151157 [atoms , self .data .grid_dimension ], "atoms+grid"
152158 )
153- feature_length = self .data .descriptor_calculator .feature_size
154159
155160 # The actual calculation of the LDOS from the descriptors depends
156161 # on whether we run in parallel or serial. In the former case,
@@ -171,7 +176,7 @@ def predict_for_atoms(
171176 return None
172177
173178 else :
174- if self . data . descriptor_calculator . descriptors_contain_xyz :
179+ if descs_with_xyz :
175180 self .data .target_calculator .local_grid = snap_descriptors [
176181 :, 0 :3
177182 ].copy ()
@@ -196,7 +201,8 @@ def predict_for_atoms(
196201 )
197202
198203 if get_rank () == 0 :
199- if self .data .descriptor_calculator .descriptors_contain_xyz :
204+ #if self.data.descriptor_calculator.descriptors_contain_xyz:
205+ if descs_with_xyz :
200206 snap_descriptors = snap_descriptors [:, :, :, 3 :]
201207 feature_length -= 3
202208
@@ -208,9 +214,8 @@ def predict_for_atoms(
208214 if save_grads is True :
209215 self .input_data = snap_descriptors
210216 self .input_data .requires_grad = True
211- return self ._forward_snap_descriptors (
212- snap_descriptors , save_torch_outputs = True
213- )
217+ return self ._forward_snap_descriptors (snap_descriptors ,
218+ save_torch_outputs = True )
214219 else :
215220 return self ._forward_snap_descriptors (snap_descriptors )
216221
0 commit comments