5353from pymc .step_methods import Metropolis
5454from pymc .testing import assert_support_point_is_expected
5555
56+ # Raise for any warnings in this file
57+ pytestmark = pytest .mark .filterwarnings ("error" )
58+
5659
5760class TestCustomDist :
5861 @pytest .mark .parametrize ("size" , [(), (3 ,), (3 , 2 )], ids = str )
@@ -105,24 +108,24 @@ def test_custom_dist_without_random(self):
105108 with pytest .raises (NotImplementedError ):
106109 sample_posterior_predictive (idata , model = model )
107110
108- @pytest .mark .xfail (
109- NotImplementedError ,
110- reason = "Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388" ,
111- )
112111 @pytest .mark .parametrize ("size" , [(), (3 ,), (3 , 2 )], ids = str )
113112 def test_custom_dist_with_random_multivariate (self , size ):
113+ def random (mu , rng , size ):
114+ return rng .multivariate_normal (
115+ mean = mu .ravel (),
116+ cov = np .eye (mu .shape [- 1 ]),
117+ size = size ,
118+ )
119+
114120 supp_shape = 5
115121 with Model () as model :
116122 mu = Normal ("mu" , 0 , 1 , size = supp_shape )
117123 obs = CustomDist (
118124 "custom_dist" ,
119125 mu ,
120- random = lambda mu , rng = None , size = None : rng .multivariate_normal (
121- mean = mu , cov = np .eye (len (mu )), size = size
122- ),
126+ random = random ,
123127 observed = np .random .randn (100 , * size , supp_shape ),
124- ndims_params = [1 ],
125- ndim_supp = 1 ,
128+ signature = "(n)->(n)" ,
126129 )
127130
128131 assert isinstance (obs .owner .op , CustomDistRV )
@@ -156,20 +159,16 @@ def test_custom_dist_old_api_error(self):
156159 ):
157160 CustomDist ("a" , lambda x : x )
158161
159- @pytest .mark .xfail (
160- NotImplementedError ,
161- reason = "Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388" ,
162- )
163162 @pytest .mark .parametrize ("size" , [None , (), (2 ,)], ids = str )
164163 def test_custom_dist_multivariate_logp (self , size ):
165164 supp_shape = 5
166165 with Model () as model :
167166
168167 def logp (value , mu ):
169- return MvNormal .logp (value , mu , pt .eye (mu .shape [0 ]))
168+ return MvNormal .logp (value , mu , pt .eye (mu .shape [- 1 ]))
170169
171170 mu = Normal ("mu" , size = supp_shape )
172- a = CustomDist ("a" , mu , logp = logp , ndims_params = [ 1 ], ndim_supp = 1 , size = size )
171+ a = CustomDist ("a" , mu , logp = logp , signature = "(n)->(n)" , size = size )
173172
174173 assert isinstance (a .owner .op , CustomDistRV )
175174 mu_test_value = npr .normal (loc = 0 , scale = 1 , size = supp_shape ).astype (pytensor .config .floatX )
@@ -219,10 +218,6 @@ def density_support_point(rv, size, mu):
219218 assert evaled_support_point .shape == to_tuple (size )
220219 assert np .all (evaled_support_point == mu_val )
221220
222- @pytest .mark .xfail (
223- NotImplementedError ,
224- reason = "Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388" ,
225- )
226221 @pytest .mark .parametrize ("size" , [(), (2 ,), (3 , 2 )], ids = str )
227222 def test_custom_dist_custom_support_point_multivariate (self , size ):
228223 def density_support_point (rv , size , mu ):
@@ -235,19 +230,14 @@ def density_support_point(rv, size, mu):
235230 "a" ,
236231 mu ,
237232 support_point = density_support_point ,
238- ndims_params = [1 ],
239- ndim_supp = 1 ,
233+ signature = "(n)->(n)" ,
240234 size = size ,
241235 )
242236 assert isinstance (a .owner .op , CustomDistRV )
243237 evaled_support_point = support_point (a ).eval ({mu : mu_val })
244238 assert evaled_support_point .shape == (* to_tuple (size ), 5 )
245239 assert np .all (evaled_support_point == mu_val )
246240
247- @pytest .mark .xfail (
248- NotImplementedError ,
249- reason = "Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388" ,
250- )
251241 @pytest .mark .parametrize (
252242 "with_random, size" ,
253243 [
@@ -267,21 +257,14 @@ def _random(mu, rng=None, size=None):
267257 else :
268258 random = None
269259
270- mu_val = np .random .normal (loc = 2 , scale = 1 , size = 5 ).astype (pytensor .config .floatX )
271260 with Model ():
272261 mu = Normal ("mu" , size = 5 )
273- a = CustomDist ("a" , mu , random = random , ndims_params = [ 1 ], ndim_supp = 1 , size = size )
262+ a = CustomDist ("a" , mu , random = random , signature = "(n)->(n)" , size = size )
274263 assert isinstance (a .owner .op , CustomDistRV )
275264 if with_random :
276- evaled_support_point = support_point (a ).eval ({ mu : mu_val } )
265+ evaled_support_point = support_point (a ).eval ()
277266 assert evaled_support_point .shape == (* to_tuple (size ), 5 )
278267 assert np .all (evaled_support_point == 0 )
279- else :
280- with pytest .raises (
281- TypeError ,
282- match = "Cannot safely infer the size of a multivariate random variable's support_point." ,
283- ):
284- evaled_support_point = support_point (a ).eval ({mu : mu_val })
285268
286269 def test_dist (self ):
287270 mu = 1
@@ -300,6 +283,12 @@ def test_dist(self):
300283 x_logp = logp (x , test_value )
301284 assert np .allclose (x_logp .eval (), st .norm (1 ).logpdf (test_value ))
302285
286+ def test_multivariate_insufficient_signature (self ):
287+ with pytest .raises (
288+ NotImplementedError , match = "signature is not sufficient to infer the support shape"
289+ ):
290+ CustomDist .dist (signature = "(n)->(m)" )
291+
303292
304293class TestCustomSymbolicDist :
305294 def test_basic (self ):
0 commit comments