@@ -695,12 +695,39 @@ def test_no_init_nuts_compound(caplog):
695695
696696
697697def test_sample_var_names ():
698- with pm .Model () as model :
699- a = pm .Normal ("a" )
700- b = pm .Deterministic ("b" , a ** 2 )
701- idata = pm .sample (10 , tune = 10 , var_names = ["a" ])
702- assert "a" in idata .posterior
703- assert "b" not in idata .posterior
698+ # Generate data
699+ seed = 1234
700+ rng = np .random .default_rng (seed )
701+
702+ group = rng .choice (list ("ABCD" ), size = 100 )
703+ x = rng .normal (size = 100 )
704+ y = rng .normal (size = 100 )
705+
706+ group_values , group_idx = np .unique (group , return_inverse = True )
707+
708+ coords = {"group" : group_values }
709+
710+ # Create model
711+ with pm .Model (coords = coords ) as model :
712+ b_group = pm .Normal ("b_group" , dims = "group" )
713+ b_x = pm .Normal ("b_x" )
714+ mu = pm .Deterministic ("mu" , b_group [group_idx ] + b_x * x )
715+ sigma = pm .HalfNormal ("sigma" )
716+ pm .Normal ("y" , mu = mu , sigma = sigma , observed = y )
717+
718+ # Sample with and without var_names, but always with the same seed
719+ with model :
720+ idata_1 = pm .sample (tune = 100 , draws = 100 , random_seed = seed )
721+ idata_2 = pm .sample (
722+ tune = 100 , draws = 100 , var_names = ["b_group" , "b_x" , "sigma" ], random_seed = seed
723+ )
724+
725+ assert "mu" in idata_1 .posterior
726+ assert "mu" not in idata_2 .posterior
727+
728+ assert np .all (idata_1 .posterior ["b_group" ] == idata_2 .posterior ["b_group" ]).item ()
729+ assert np .all (idata_1 .posterior ["b_x" ] == idata_2 .posterior ["b_x" ]).item ()
730+ assert np .all (idata_1 .posterior ["sigma" ] == idata_2 .posterior ["sigma" ]).item ()
704731
705732
706733class TestAssignStepMethods :
0 commit comments