2323from pytensor import tensor as pt
2424from scipy import stats as st
2525
26- import pymc as pm
27-
28- from pymc import (
29- CustomDist ,
30- Deterministic ,
26+ from pymc . distributions import (
27+ Bernoulli ,
28+ Beta ,
29+ Categorical ,
30+ ChiSquared ,
3131 DiracDelta ,
32+ Flat ,
3233 HalfNormal ,
3334 LogNormal ,
34- Model ,
35+ Mixture ,
36+ MvNormal ,
3537 Normal ,
36- draw ,
37- logcdf ,
38- logp ,
39- sample ,
38+ NormalMixture ,
39+ RandomWalk ,
40+ StudentT ,
41+ Truncated ,
42+ Uniform ,
4043)
41- from pymc .distributions .custom import CustomDistRV , CustomSymbolicDistRV
44+ from pymc .distributions .custom import CustomDist , CustomDistRV , CustomSymbolicDistRV
4245from pymc .distributions .distribution import support_point
4346from pymc .distributions .shape_utils import change_dist_size , rv_size_is_none , to_tuple
4447from pymc .distributions .transforms import log
4548from pymc .exceptions import BlockModelAccessError
49+ from pymc .logprob import logcdf , logp
50+ from pymc .model import Deterministic , Model
4651from pymc .pytensorf import collect_default_updates
52+ from pymc .sampling import draw , sample , sample_posterior_predictive
53+ from pymc .step_methods import Metropolis
4754from pymc .testing import assert_support_point_is_expected
4855
4956
@@ -88,15 +95,15 @@ def test_custom_dist_without_random(self):
8895 custom_dist = CustomDist (
8996 "custom_dist" ,
9097 mu ,
91- logp = lambda value , mu : logp (pm . Normal .dist (mu , 1 , size = 100 ), value ),
98+ logp = lambda value , mu : logp (Normal .dist (mu , 1 , size = 100 ), value ),
9299 observed = np .random .randn (100 ),
93100 initval = 0 ,
94101 )
95102 assert isinstance (custom_dist .owner .op , CustomDistRV )
96- idata = sample (tune = 50 , draws = 100 , cores = 1 , step = pm . Metropolis ())
103+ idata = sample (tune = 50 , draws = 100 , cores = 1 , step = Metropolis ())
97104
98105 with pytest .raises (NotImplementedError ):
99- pm . sample_posterior_predictive (idata , model = model )
106+ sample_posterior_predictive (idata , model = model )
100107
101108 @pytest .mark .xfail (
102109 NotImplementedError ,
@@ -159,7 +166,7 @@ def test_custom_dist_multivariate_logp(self, size):
159166 with Model () as model :
160167
161168 def logp (value , mu ):
162- return pm . MvNormal .logp (value , mu , pt .eye (mu .shape [0 ]))
169+ return MvNormal .logp (value , mu , pt .eye (mu .shape [0 ]))
163170
164171 mu = Normal ("mu" , size = supp_shape )
165172 a = CustomDist ("a" , mu , logp = logp , ndims_params = [1 ], ndim_supp = 1 , size = size )
@@ -184,14 +191,14 @@ def logp(value, mu):
184191 def test_custom_dist_default_support_point_univariate (self , support_point , size , expected ):
185192 if support_point == "custom_support_point" :
186193 support_point = lambda rv , size , * rv_inputs : 5 * pt .ones (size , dtype = rv .dtype ) # noqa E731
187- with pm . Model () as model :
194+ with Model () as model :
188195 x = CustomDist ("x" , support_point = support_point , size = size )
189196 assert isinstance (x .owner .op , CustomDistRV )
190197 assert_support_point_is_expected (model , expected , check_finite_logp = False )
191198
192199 def test_custom_dist_moment_future_warning (self ):
193200 moment = lambda rv , size , * rv_inputs : 5 * pt .ones (size , dtype = rv .dtype ) # noqa E731
194- with pm . Model () as model :
201+ with Model () as model :
195202 with pytest .warns (
196203 FutureWarning , match = "`moment` argument is deprecated. Use `support_point` instead."
197204 ):
@@ -280,24 +287,24 @@ def test_dist(self):
280287 mu = 1
281288 x = CustomDist .dist (
282289 mu ,
283- logp = lambda value , mu : pm . logp (pm . Normal .dist (mu ), value ),
290+ logp = lambda value , mu : logp (Normal .dist (mu ), value ),
284291 random = lambda mu , rng = None , size = None : rng .normal (loc = mu , scale = 1 , size = size ),
285292 shape = (3 ,),
286293 )
287294
288295 x = cloudpickle .loads (cloudpickle .dumps (x ))
289296
290- test_value = pm . draw (x , random_seed = 1 )
291- assert np .all (test_value == pm . draw (x , random_seed = 1 ))
297+ test_value = draw (x , random_seed = 1 )
298+ assert np .all (test_value == draw (x , random_seed = 1 ))
292299
293- x_logp = pm . logp (x , test_value )
300+ x_logp = logp (x , test_value )
294301 assert np .allclose (x_logp .eval (), st .norm (1 ).logpdf (test_value ))
295302
296303
297304class TestCustomSymbolicDist :
298305 def test_basic (self ):
299306 def custom_dist (mu , sigma , size ):
300- return pt .exp (pm . Normal .dist (mu , sigma , size = size ))
307+ return pt .exp (Normal .dist (mu , sigma , size = size ))
301308
302309 with Model () as m :
303310 mu = Normal ("mu" )
@@ -315,7 +322,7 @@ def custom_dist(mu, sigma, size):
315322 assert isinstance (lognormal .owner .op , CustomSymbolicDistRV )
316323
317324 # Fix mu and sigma, so that all source of randomness comes from the symbolic RV
318- draws = pm . draw (lognormal , draws = 3 , givens = {mu : 0.0 , sigma : 1.0 })
325+ draws = draw (lognormal , draws = 3 , givens = {mu : 0.0 , sigma : 1.0 })
319326 assert draws .shape == (3 , 10 )
320327 assert np .unique (draws ).size == 30
321328
@@ -334,31 +341,31 @@ def custom_dist(mu, sigma, size):
334341 (5 , 1 ),
335342 None ,
336343 np .exp (5 ),
337- lambda mu , sigma , size : pt .exp (pm . Normal .dist (mu , sigma , size = size )),
344+ lambda mu , sigma , size : pt .exp (Normal .dist (mu , sigma , size = size )),
338345 ),
339346 (
340347 (2 , np .ones (5 )),
341348 None ,
342349 np .exp (2 + np .ones (5 )),
343- lambda mu , sigma , size : pt .exp (pm . Normal .dist (mu , sigma , size = size ) + 1.0 ),
350+ lambda mu , sigma , size : pt .exp (Normal .dist (mu , sigma , size = size ) + 1.0 ),
344351 ),
345352 (
346353 (1 , 2 ),
347354 None ,
348355 np .sqrt (np .exp (1 + 0.5 * 2 ** 2 )),
349- lambda mu , sigma , size : pt .sqrt (pm . LogNormal .dist (mu , sigma , size = size )),
356+ lambda mu , sigma , size : pt .sqrt (LogNormal .dist (mu , sigma , size = size )),
350357 ),
351358 (
352359 (4 ,),
353360 (3 ,),
354361 np .log ([4 , 4 , 4 ]),
355- lambda nu , size : pt .log (pm . ChiSquared .dist (nu , size = size )),
362+ lambda nu , size : pt .log (ChiSquared .dist (nu , size = size )),
356363 ),
357364 (
358365 (12 , 1 ),
359366 None ,
360367 12 ,
361- lambda mu1 , sigma , size : pm . Normal .dist (mu1 , sigma , size = size ),
368+ lambda mu1 , sigma , size : Normal .dist (mu1 , sigma , size = size ),
362369 ),
363370 ],
364371 )
@@ -369,7 +376,7 @@ def test_custom_dist_default_support_point(self, dist_params, size, expected, di
369376
370377 def test_custom_dist_default_support_point_scan (self ):
371378 def scan_step (left , right ):
372- x = pm . Uniform .dist (left , right )
379+ x = Uniform .dist (left , right )
373380 x_update = collect_default_updates ([x ])
374381 return x , x_update
375382
@@ -390,7 +397,7 @@ def dist(size):
390397
391398 def test_custom_dist_default_support_point_scan_recurring (self ):
392399 def scan_step (xtm1 ):
393- x = pm . Normal .dist (xtm1 + 1 )
400+ x = Normal .dist (xtm1 + 1 )
394401 x_update = collect_default_updates ([x ])
395402 return x , x_update
396403
@@ -417,15 +424,15 @@ def dist(size):
417424 )
418425 def test_custom_dist_default_support_point_nested (self , left , right , size , expected ):
419426 def dist_fn (left , right , size ):
420- return pm . Truncated .dist (pm . Normal .dist (0 , 1 ), left , right , size = size ) + 5
427+ return Truncated .dist (Normal .dist (0 , 1 ), left , right , size = size ) + 5
421428
422429 with Model () as model :
423430 CustomDist ("x" , left , right , size = size , dist = dist_fn )
424431 assert_support_point_is_expected (model , expected )
425432
426433 def test_logcdf_inference (self ):
427434 def custom_dist (mu , sigma , size ):
428- return pt .exp (pm . Normal .dist (mu , sigma , size = size ))
435+ return pt .exp (Normal .dist (mu , sigma , size = size ))
429436
430437 mu = 1
431438 sigma = 1.25
@@ -435,16 +442,16 @@ def custom_dist(mu, sigma, size):
435442 ref_lognormal = LogNormal .dist (mu , sigma )
436443
437444 np .testing .assert_allclose (
438- pm . logcdf (custom_lognormal , test_value ).eval (),
439- pm . logcdf (ref_lognormal , test_value ).eval (),
445+ logcdf (custom_lognormal , test_value ).eval (),
446+ logcdf (ref_lognormal , test_value ).eval (),
440447 )
441448
442449 def test_random_multiple_rngs (self ):
443450 def custom_dist (p , sigma , size ):
444- idx = pm . Bernoulli .dist (p = p )
451+ idx = Bernoulli .dist (p = p )
445452 if rv_size_is_none (size ):
446453 size = pt .broadcast_shape (p , sigma )
447- comps = pm . Normal .dist ([- sigma , sigma ], 1e-1 , size = (* size , 2 )).T
454+ comps = Normal .dist ([- sigma , sigma ], 1e-1 , size = (* size , 2 )).T
448455 return comps [idx ]
449456
450457 customdist = CustomDist .dist (
@@ -461,7 +468,7 @@ def custom_dist(p, sigma, size):
461468 assert len (node .outputs ) == 3 # RV and 2 updated RNGs
462469 assert len (node .op .update (node )) == 2
463470
464- draws = pm . draw (customdist , draws = 2 , random_seed = 123 )
471+ draws = draw (customdist , draws = 2 , random_seed = 123 )
465472 assert np .unique (draws ).size == 20
466473
467474 def test_custom_methods (self ):
@@ -494,7 +501,7 @@ def custom_logcdf(value, mu):
494501
495502 def test_change_size (self ):
496503 def custom_dist (mu , sigma , size ):
497- return pt .exp (pm . Normal .dist (mu , sigma , size = size ))
504+ return pt .exp (Normal .dist (mu , sigma , size = size ))
498505
499506 lognormal = CustomDist .dist (
500507 0 ,
@@ -515,9 +522,9 @@ def custom_dist(mu, sigma, size):
515522
516523 def test_error_model_access (self ):
517524 def custom_dist (size ):
518- return pm . Flat ("Flat" , size = size )
525+ return Flat ("Flat" , size = size )
519526
520- with pm . Model () as m :
527+ with Model () as m :
521528 with pytest .raises (
522529 BlockModelAccessError ,
523530 match = "Model variables cannot be created in the dist function" ,
@@ -526,7 +533,7 @@ def custom_dist(size):
526533
527534 def test_api_change_error (self ):
528535 def old_random (size ):
529- return pm . Flat .dist (size = size )
536+ return Flat .dist (size = size )
530537
531538 # Old API raises
532539 with pytest .raises (TypeError , match = "API change: function passed to `random` argument" ):
@@ -541,7 +548,7 @@ def trw(nu, sigma, steps, size):
541548 size = ()
542549
543550 def step (xtm1 , nu , sigma ):
544- x = pm . StudentT .dist (nu = nu , mu = xtm1 , sigma = sigma , shape = size )
551+ x = StudentT .dist (nu = nu , mu = xtm1 , sigma = sigma , shape = size )
545552 return x , collect_default_updates ([x ])
546553
547554 xs , _ = scan (
@@ -562,52 +569,50 @@ def step(xtm1, nu, sigma):
562569 batch_size = 3
563570 x = CustomDist .dist (nu , sigma , steps , dist = trw , size = batch_size )
564571
565- x_draw = pm . draw (x , random_seed = 1 )
572+ x_draw = draw (x , random_seed = 1 )
566573 assert x_draw .shape == (steps , batch_size )
567- np .testing .assert_allclose (pm . draw (x , random_seed = 1 ), x_draw )
568- assert not np .any (pm . draw (x , random_seed = 2 ) == x_draw )
574+ np .testing .assert_allclose (draw (x , random_seed = 1 ), x_draw )
575+ assert not np .any (draw (x , random_seed = 2 ) == x_draw )
569576
570- ref_dist = pm . RandomWalk .dist (
571- init_dist = pm . Flat .dist (),
572- innovation_dist = pm . StudentT .dist (nu = nu , sigma = sigma ),
577+ ref_dist = RandomWalk .dist (
578+ init_dist = Flat .dist (),
579+ innovation_dist = StudentT .dist (nu = nu , sigma = sigma ),
573580 steps = steps ,
574581 size = (batch_size ,),
575582 )
576583 ref_val = pt .concatenate ([np .zeros ((1 , batch_size )), x_draw ]).T
577584
578585 np .testing .assert_allclose (
579- pm . logp (x , x_draw ).eval ().sum (0 ),
580- pm . logp (ref_dist , ref_val ).eval (),
586+ logp (x , x_draw ).eval ().sum (0 ),
587+ logp (ref_dist , ref_val ).eval (),
581588 )
582589
583590 def test_inferred_logp_mixture (self ):
584591 import numpy as np
585592
586- import pymc as pm
587-
588593 def shifted_normal (mu , sigma , size ):
589- return mu + pm . Normal .dist (0 , sigma , shape = size )
594+ return mu + Normal .dist (0 , sigma , shape = size )
590595
591596 mus = [3.5 , - 4.3 ]
592597 sds = [1.5 , 2.3 ]
593598 w = [0.3 , 0.7 ]
594- with pm . Model () as m :
599+ with Model () as m :
595600 comp_dists = [
596601 CustomDist .dist (mus [0 ], sds [0 ], dist = shifted_normal ),
597602 CustomDist .dist (mus [1 ], sds [1 ], dist = shifted_normal ),
598603 ]
599- pm . Mixture ("mix" , w = w , comp_dists = comp_dists )
604+ Mixture ("mix" , w = w , comp_dists = comp_dists )
600605
601606 test_value = 0.1
602607 np .testing .assert_allclose (
603608 m .compile_logp ()({"mix" : test_value }),
604- pm . logp (pm . NormalMixture .dist (w = w , mu = mus , sigma = sds ), test_value ).eval (),
609+ logp (NormalMixture .dist (w = w , mu = mus , sigma = sds ), test_value ).eval (),
605610 )
606611
607612 def test_symbolic_dist (self ):
608613 # Test we can create a SymbolicDist inside a CustomDist
609614 def dist (size ):
610- return pm . Truncated .dist (pm . Beta .dist (1 , 1 , size = size ), lower = 0.1 , upper = 0.9 )
615+ return Truncated .dist (Beta .dist (1 , 1 , size = size ), lower = 0.1 , upper = 0.9 )
611616
612617 assert CustomDist .dist (dist = dist )
613618
@@ -616,20 +621,20 @@ def test_nested_custom_dist(self):
616621
617622 def dist (size = None ):
618623 def inner_dist (size = None ):
619- return pm . Normal .dist (size = size )
624+ return Normal .dist (size = size )
620625
621626 inner_dist = CustomDist .dist (dist = inner_dist , size = size )
622627 return pt .exp (inner_dist )
623628
624629 rv = CustomDist .dist (dist = dist )
625630 np .testing .assert_allclose (
626- pm . logp (rv , 1.0 ).eval (),
627- pm . logp (pm . LogNormal .dist (), 1.0 ).eval (),
631+ logp (rv , 1.0 ).eval (),
632+ logp (LogNormal .dist (), 1.0 ).eval (),
628633 )
629634
630635 def test_signature (self ):
631636 def dist (p , size ):
632- return - pm . Categorical .dist (p = p , size = size )
637+ return - Categorical .dist (p = p , size = size )
633638
634639 out = CustomDist .dist ([0.25 , 0.75 ], dist = dist , signature = "(p)->()" )
635640 # Size and updates are added automatically to the signature
0 commit comments