@@ -256,6 +256,60 @@ def test_insert_draw(self):
256256 chain ._get_row_at (2 , var_names = ["v1" ])
257257 pass
258258
def test_to_inferencedata_equalize_chain_lengths(self, caplog):
    """Check ``to_inferencedata`` on a run whose chains differ in length.

    Two chains are filled with 20 and 19 iterations respectively, the first
    ``ntune`` of which are flagged as tuning via the ``tune`` stat.

    The test then asserts that:

    * ``equalize_chain_lengths=True`` logs a "Chains vary in length" /
      "Truncating to" warning and yields a posterior whose ``draw``
      dimension has the length of the shortest chain minus tuning.
    * ``equalize_chain_lengths=False`` logs a "Chains vary in length"
      warning that refers to ArviZ issue #2094 and yields the degenerate
      ``chain=1, draw=2`` layout in which every "draw" holds one chain.
    """
    run, chains = fully_initialized(
        self.backend,
        make_runmeta(
            variables=[
                Variable("A", "uint16", []),
            ],
            sample_stats=[Variable("tune", "bool")],
            data=[],
        ),
        nchains=2,
    )

    # Create chains of uneven lengths to simulate unsynchronized chains:
    # - Chain 0 has 5 tune and 15 draws (length 20)
    # - Chain 1 has 5 tune and 14 draws (length 19)
    ntune = 5
    lengths = (20, 19)

    for cid, length in enumerate(lengths):
        chain = chains[cid]
        for i in range(length):
            chain.append(dict(A=i), stats=dict(tune=i < ntune))
        assert len(chain) == length

    # With equalize=True all chains are cut to the length of the shortest
    # (here: 19). The first 5 are tuning, so 14 posterior draws remain.
    expected_draws = min(lengths) - ntune
    # Start from an empty record list so records[0] is the warning under test,
    # mirroring the clear() done before the second conversion below.
    caplog.clear()
    with caplog.at_level(logging.WARNING):
        idata_even = run.to_inferencedata(equalize_chain_lengths=True)
    assert "Chains vary in length" in caplog.records[0].message
    assert "Truncating to" in caplog.records[0].message
    assert len(idata_even.posterior.draw) == expected_draws

    # With equalize=False nothing is truncated: the chains keep their
    # 20-5 = 15 and 19-5 = 14 posterior draws, which cannot form a
    # rectangular (chain, draw) array.
    caplog.clear()
    with caplog.at_level(logging.WARNING):
        idata_uneven = run.to_inferencedata(equalize_chain_lengths=False)
    # These are the messed-up chain and draw dimensions!
    assert idata_uneven.posterior.dims["chain"] == 1
    assert idata_uneven.posterior.dims["draw"] == 2
    # The "draws" are actually the chains, but in a weird scalar object-array?!
    # Doing .tolist() seems to be the only way to get our hands on it.
    d1 = idata_uneven.posterior.A.sel(chain=0, draw=0).values.tolist()
    d2 = idata_uneven.posterior.A.sel(chain=0, draw=1).values.tolist()
    numpy.testing.assert_array_equal(d1, list(range(ntune, lengths[0])))
    numpy.testing.assert_array_equal(d2, list(range(ntune, lengths[1])))
    assert "Chains vary in length" in caplog.records[0].message
    assert "see ArviZ issue #2094" in caplog.records[0].message
312+
259313
260314if __name__ == "__main__" :
261315 tc = TestClickHouseBackend ()
0 commit comments