@@ -188,9 +188,13 @@ def to_inferencedata(self, **kwargs) -> InferenceData:
188188 posterior = collections .defaultdict (list )
189189 sample_stats = collections .defaultdict (list )
190190 for c , chain in enumerate (chains ):
191+ # Every retrieved array is shortened to the previously determined chain length.
192+ # This is needed for database backends which may get inserts inbetween.
193+ clen = chain_lengths [chain .cid ]
194+
191195 # Obtain a mask by which draws can be split into warmup/posterior
192196 if "tune" in chain .sample_stats :
193- tune = chain .get_stats ("tune" ).astype (bool )
197+ tune = chain .get_stats ("tune" )[: clen ] .astype (bool )
194198 else :
195199 if c == 0 :
196200 _log .warning (
@@ -200,12 +204,12 @@ def to_inferencedata(self, **kwargs) -> InferenceData:
200204
201205 # Split all variables draws into warmup/posterior
202206 for var in variables :
203- draws = chain .get_draws (var .name )
207+ draws = chain .get_draws (var .name )[: clen ]
204208 warmup_posterior [var .name ].append (draws [tune ])
205209 posterior [var .name ].append (draws [~ tune ])
206210 # Same for sample stats
207211 for svar in self .meta .sample_stats :
208- stats = chain .get_stats (svar .name )
212+ stats = chain .get_stats (svar .name )[: clen ]
209213 warmup_sample_stats [svar .name ].append (stats [tune ])
210214 sample_stats [svar .name ].append (stats [~ tune ])
211215
0 commit comments