|
30 | 30 | from aesara.tensor.basic import as_tensor_variable |
31 | 31 | from aesara.tensor.elemwise import Elemwise |
32 | 32 | from aesara.tensor.random.op import RandomVariable |
33 | | -from aesara.tensor.random.var import RandomStateSharedVariable |
34 | 33 | from aesara.tensor.var import TensorVariable |
35 | 34 | from typing_extensions import TypeAlias |
36 | 35 |
|
@@ -358,23 +357,6 @@ def dist( |
358 | 357 | replicate_shape = cast(StrongShape, shape[:-1]) |
359 | 358 | rv_out = change_rv_size(rv_var=rv_out, new_size=replicate_shape, expand=True) |
360 | 359 |
|
361 | | - rng = kwargs.pop("rng", None) |
362 | | - if ( |
363 | | - rv_out.owner |
364 | | - and isinstance(rv_out.owner.op, RandomVariable) |
365 | | - and isinstance(rng, RandomStateSharedVariable) |
366 | | - and not getattr(rng, "default_update", None) |
367 | | - ): |
368 | | - # This tells `aesara.function` that the shared RNG variable |
369 | | - # is mutable, which--in turn--tells the `FunctionGraph` |
370 | | - # `Supervisor` feature to allow in-place updates on the variable. |
371 | | - # Without it, the `RandomVariable`s could not be optimized to allow |
372 | | - # in-place RNG updates, forcing all sample results from compiled |
373 | | - # functions to be the same on repeated evaluations. |
374 | | - new_rng = rv_out.owner.outputs[0] |
375 | | - rv_out.update = (rng, new_rng) |
376 | | - rng.default_update = new_rng |
377 | | - |
378 | 360 | rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)") |
379 | 361 | rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)") |
380 | 362 | rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()") |
@@ -589,27 +571,6 @@ def dist( |
589 | 571 | replicate_shape = cast(StrongShape, shape[:-1]) |
590 | 572 | graph = cls.change_size(rv=graph, new_size=replicate_shape, expand=True) |
591 | 573 |
|
592 | | - rngs = kwargs.pop("rngs", None) |
593 | | - if rngs is not None: |
594 | | - graph_rvs = cls.graph_rvs(graph) |
595 | | - assert len(rngs) == len(graph_rvs) |
596 | | - for rng, rv_out in zip(rngs, graph_rvs): |
597 | | - if ( |
598 | | - rv_out.owner |
599 | | - and isinstance(rv_out.owner.op, RandomVariable) |
600 | | - and isinstance(rng, RandomStateSharedVariable) |
601 | | - and not getattr(rng, "default_update", None) |
602 | | - ): |
603 | | - # This tells `aesara.function` that the shared RNG variable |
604 | | - # is mutable, which--in turn--tells the `FunctionGraph` |
605 | | - # `Supervisor` feature to allow in-place updates on the variable. |
606 | | - # Without it, the `RandomVariable`s could not be optimized to allow |
607 | | - # in-place RNG updates, forcing all sample results from compiled |
608 | | - # functions to be the same on repeated evaluations. |
609 | | - new_rng = rv_out.owner.outputs[0] |
610 | | - rv_out.update = (rng, new_rng) |
611 | | - rng.default_update = new_rng |
612 | | - |
613 | 574 | # TODO: Create new attr error stating that these are not available for DerivedDistribution |
614 | 575 | # rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)") |
615 | 576 | # rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)") |
|
0 commit comments