@@ -1048,7 +1048,11 @@ def change_custom_dist_size(op, rv, new_size, expand):
10481048
10491049 return new_rv
10501050
1051- rngs , rngs_updates = zip (* dummy_updates_dict .items ())
1051+ if dummy_updates_dict :
1052+ rngs , rngs_updates = zip (* dummy_updates_dict .items ())
1053+ else :
1054+ rngs , rngs_updates = (), ()
1055+
10521056 inputs = [* dummy_params , * rngs ]
10531057 outputs = [dummy_rv , * rngs_updates ]
10541058 signature = cls ._infer_final_signature (
@@ -1497,19 +1501,26 @@ def default_support_point(rv, size, *rv_inputs, rv_name=None, has_fallback=False
14971501 )
14981502
14991503
1500- class DiracDeltaRV (RandomVariable ):
1504+ class DiracDeltaRV (SymbolicRandomVariable ):
15011505 name = "diracdelta"
1502- signature = "()->()"
1506+ extended_signature = "[size], ()->()"
15031507 _print_name = ("DiracDelta" , "\\ operatorname{DiracDelta}" )
15041508
1509+ def do_constant_folding (self , fgraph : "FunctionGraph" , node : Apply ) -> bool :
1510+ # Because the distribution does not have RNGs we have to prevent constant-folding
1511+ return False
1512+
15051513 @classmethod
1506- def rng_fn (cls , rng , c , size = None ):
1507- if size is None :
1508- return c .copy ()
1509- return np .full (size , c )
1514+ def rv_op (cls , c , * , size = None , rng = None ):
1515+ size = normalize_size_param (size )
1516+ c = pt .as_tensor (c )
15101517
1518+ if rv_size_is_none (size ):
1519+ out = c .copy ()
1520+ else :
1521+ out = pt .full (size , c )
15111522
1512- diracdelta = DiracDeltaRV ( )
1523+ return cls ( inputs = [ size , c ], outputs = [ out ])( size , c )
15131524
15141525
15151526class DiracDelta (Discrete ):
@@ -1524,14 +1535,15 @@ class DiracDelta(Discrete):
15241535 that use DiracDelta, such as Mixtures.
15251536 """
15261537
1527- rv_op = diracdelta
1538+ rv_type = DiracDeltaRV
1539+ rv_op = DiracDeltaRV .rv_op
15281540
15291541 @classmethod
15301542 def dist (cls , c , * args , ** kwargs ):
15311543 c = pt .as_tensor_variable (c )
15321544 if c .dtype in continuous_types :
15331545 c = floatX (c )
1534- return super ().dist ([c ], dtype = c . dtype , ** kwargs )
1546+ return super ().dist ([c ], ** kwargs )
15351547
15361548 def support_point (rv , size , c ):
15371549 if not rv_size_is_none (size ):
0 commit comments