1515# limitations under the License.
1616
1717from multiprocessing import Manager
18- import aesara .tensor as at
1918import numpy as np
20-
21- from aeppl .logprob import _logprob
22- from aesara .tensor .random .op import RandomVariable
23- from aesara .tensor .var import Variable
24-
2519from pandas import DataFrame , Series
2620
2721from pymc .distributions .distribution import Distribution , _moment
22+ from pymc .logprob .abstract import _logprob
23+ import pytensor .tensor as pt
24+ from pytensor .tensor .random .op import RandomVariable
25+
2826
2927from .utils import _sample_posterior
3028
@@ -42,11 +40,7 @@ class BARTRV(RandomVariable):
4240 all_trees = None
4341
4442 def _supp_shape_from_params (self , dist_params , rep_param_idx = 1 , param_shapes = None ):
45- if isinstance (self .X , Variable ):
46- shape = self .X .shape [0 ].eval ()
47- else :
48- shape = self .X .shape [0 ]
49- return (shape ,)
43+ return dist_params [0 ].shape [:1 ]
5044
5145 @classmethod
5246 def rng_fn (cls , rng = None , X = None , Y = None , m = None , alpha = None , split_prior = None , size = None ):
@@ -145,11 +139,11 @@ def logp(self, x, *inputs):
145139 -------
146140 TensorVariable
147141 """
148- return at .zeros_like (x )
142+ return pt .zeros_like (x )
149143
150144 @classmethod
151145 def get_moment (cls , rv , size , * rv_inputs ):
152- mean = at .fill (size , rv .Y .mean ())
146+ mean = pt .fill (size , rv .Y .mean ())
153147 return mean
154148
155149
0 commit comments