File tree Expand file tree Collapse file tree 2 files changed +29
-3
lines changed
Expand file tree Collapse file tree 2 files changed +29
-3
lines changed Original file line number Diff line number Diff line change 3939 Apply ,
4040 Constant ,
4141 Variable ,
42- ancestors ,
4342 clone_get_equiv ,
4443 graph_inputs ,
44+ vars_between ,
4545 walk ,
4646)
4747from aesara .graph .fg import FunctionGraph
@@ -975,8 +975,8 @@ def compile_pymc(
975975 output_to_list = outputs if isinstance (outputs , (list , tuple )) else [outputs ]
976976 for rv in (
977977 node
978- for node in ancestors ( output_to_list )
979- if node .owner and isinstance (node .owner .op , RandomVariable )
978+ for node in vars_between ( inputs , output_to_list )
979+ if node .owner and isinstance (node .owner .op , RandomVariable ) and node not in inputs
980980 ):
981981 rng = rv .owner .inputs [0 ]
982982 if not hasattr (rng , "default_update" ):
Original file line number Diff line number Diff line change @@ -636,3 +636,29 @@ def test_compile_pymc_missing_default_explicit_updates():
636636 # And again, it should be overridden by an explicit update
637637 f = compile_pymc ([], x , updates = {rng : x .owner .outputs [0 ]})
638638 assert f () != f ()
639+
640+
641+ def test_compile_pymc_updates_inputs ():
642+ """Test that compile_pymc does not include rngs updates of variables that are inputs
643+ or ancestors to inputs
644+ """
645+ x = at .random .normal ()
646+ y = at .random .normal (x )
647+ z = at .random .normal (y )
648+
649+ for inputs , rvs_in_graph in (
650+ ([], 3 ),
651+ ([x ], 2 ),
652+ ([y ], 1 ),
653+ ([z ], 0 ),
654+ ([x , y ], 1 ),
655+ ([x , y , z ], 0 ),
656+ ):
657+ fn = compile_pymc (inputs , z , on_unused_input = "ignore" )
658+ fn_fgraph = fn .maker .fgraph
659+ # Each RV adds a shared input for its rng
660+ assert len (fn_fgraph .inputs ) == len (inputs ) + rvs_in_graph
661+ # If the output is an input, the graph has a DeepCopyOp
662+ assert len (fn_fgraph .apply_nodes ) == max (rvs_in_graph , 1 )
663+ # Each RV adds a shared output for its rng
664+ assert len (fn_fgraph .outputs ) == 1 + rvs_in_graph
You can’t perform that action at this time.
0 commit comments