@@ -395,44 +395,38 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
395395def finite_discrete_marginal_rv_logp (op , values , * inputs , ** kwargs ):
396396 # Clone the inner RV graph of the Marginalized RV
397397 marginalized_rvs_node = op .make_node (* inputs )
398- marginalized_rv , * dependent_rvs = clone_replace (
398+ inner_rvs = clone_replace (
399399 op .inner_outputs ,
400400 replace = {u : v for u , v in zip (op .inner_inputs , marginalized_rvs_node .inputs )},
401401 )
402+ marginalized_rv = inner_rvs [0 ]
402403
403404 # Obtain the joint_logp graph of the inner RV graph
404- # Some inputs are not root inputs (such as transformed projections of value variables)
405- # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
406- inputs = list (inputvars (inputs ))
407- rvs_to_values = {}
408- dummy_marginalized_value = marginalized_rv .clone ()
409- rvs_to_values [marginalized_rv ] = dummy_marginalized_value
410- rvs_to_values .update (zip (dependent_rvs , values ))
411- logps_dict = factorized_joint_logprob (rv_values = rvs_to_values , ** kwargs )
405+ inner_rvs_to_values = {rv : rv .clone () for rv in inner_rvs }
406+ logps_dict = factorized_joint_logprob (rv_values = inner_rvs_to_values , ** kwargs )
412407
413408 # Reduce logp dimensions corresponding to broadcasted variables
414- values_axis_bcast = []
415- for value in values :
416- vbcast = value .type .broadcastable
417- mbcast = dummy_marginalized_value .type .broadcastable
409+ joint_logp = logps_dict [inner_rvs_to_values [marginalized_rv ]]
410+ for inner_rv , inner_value in inner_rvs_to_values .items ():
411+ if inner_rv is marginalized_rv :
412+ continue
413+ vbcast = inner_value .type .broadcastable
414+ mbcast = marginalized_rv .type .broadcastable
418415 mbcast = (True ,) * (len (vbcast ) - len (mbcast )) + mbcast
419- values_axis_bcast .append ([i for i , (m , v ) in enumerate (zip (mbcast , vbcast )) if m != v ])
420- joint_logp = logps_dict [dummy_marginalized_value ]
421- for value , values_axis_bcast in zip (values , values_axis_bcast ):
422- joint_logp += logps_dict [value ].sum (values_axis_bcast , keepdims = True )
416+ values_axis_bcast = [i for i , (m , v ) in enumerate (zip (mbcast , vbcast )) if m != v ]
417+ joint_logp += logps_dict [inner_value ].sum (values_axis_bcast , keepdims = True )
423418
424419 # Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
425420 # values of the marginalized RV
426- # OpFromGraph does not accept constant inputs
427- non_const_values = [
428- value
429- for value in rvs_to_values .values ()
430- if not isinstance (value , (Constant , SharedVariable ))
431- ]
432- joint_logp_op = OpFromGraph ([* non_const_values , * inputs ], [joint_logp ], inline = True )
421+ # Some inputs are not root inputs (such as transformed projections of value variables)
422+ # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
423+ inputs = list (inputvars (inputs ))
424+ joint_logp_op = OpFromGraph (
425+ list (inner_rvs_to_values .values ()) + inputs , [joint_logp ], inline = True
426+ )
433427
434428 # Compute the joint_logp for all possible n values of the marginalized RV. We assume
435- # each original dimension is independent so that it sufficies to evaluate the graph
429+ # each original dimension is independent so that it suffices to evaluate the graph
436430 # n times, once with each possible value of the marginalized RV replicated across
437431 # batched dimensions of the marginalized RV
438432
@@ -449,18 +443,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
449443 axis2 = - 1 ,
450444 )
451445
452- # OpFromGraph does not accept constant inputs
453- non_const_values = [
454- value for value in values if not isinstance (value , (Constant , SharedVariable ))
455- ]
456446 # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
457447 if len (marginalized_rv_domain ) <= 10 :
458448 joint_logps = [
459- joint_logp_op (marginalized_rv_domain_tensor [i ], * non_const_values , * inputs )
449+ joint_logp_op (marginalized_rv_domain_tensor [i ], * values , * inputs )
460450 for i in range (len (marginalized_rv_domain ))
461451 ]
462452 else :
463- # Make sure this is rewrite is registered
453+ # Make sure this rewrite is registered
464454 from pymc .pytensorf import local_remove_check_parameter
465455
466456 def logp_fn (marginalized_rv_const , * non_sequences ):
@@ -469,7 +459,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
469459 joint_logps , _ = scan_map (
470460 fn = logp_fn ,
471461 sequences = marginalized_rv_domain_tensor ,
472- non_sequences = [* non_const_values , * inputs ],
462+ non_sequences = [* values , * inputs ],
473463 mode = Mode ().including ("local_remove_check_parameter" ),
474464 )
475465
0 commit comments