@@ -172,24 +172,23 @@ def func(x):
172172 trace = pm .sample (draws = 10 , tune = 10 , step = pm .Metropolis (), cores = 2 , mp_ctx = "spawn" )
173173
174174
175- @pytest .mark .xfail (raises = ValueError )
176175def test_spawn_densitydist_bound_method ():
177176 with pm .Model () as model :
178177 mu = pm .Normal ("mu" , 0 , 1 )
179178 normal_dist = pm .Normal .dist (mu , 1 )
180179 obs = pm .DensityDist ("density_dist" , normal_dist .logp , observed = np .random .randn (100 ))
181- trace = pm .sample (draws = 10 , tune = 10 , step = pm .Metropolis (), cores = 2 , mp_ctx = "spawn" )
180+ msg = "logp for DensityDist is a bound method, leading to RecursionError while serializing"
181+ with pytest .raises (ValueError , match = msg ):
182+ trace = pm .sample (draws = 10 , tune = 10 , step = pm .Metropolis (), cores = 2 , mp_ctx = "spawn" )
182183
183184
184- # cannot test this properly: monkeypatching sys.platform messes up Theano
185- # def test_spawn_densitydist_syswarning(monkeypatch):
186- # monkeypatch.setattr(sys, "platform", "win32")
187- # with pm.Model() as model:
188- # mu = pm.Normal('mu', 0, 1)
189- # normal_dist = pm.Normal.dist(mu, 1)
190- # with pytest.warns(UserWarning) as w:
191- # obs = pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100))
192- # assert len(w) == 1 and "errors when sampling on platforms" in w[0].message.args[0]
185+ def test_spawn_densitydist_syswarning (monkeypatch ):
186+ monkeypatch .setattr ("pymc3.distributions.distribution.PLATFORM" , "win32" )
187+ with pm .Model () as model :
188+ mu = pm .Normal ("mu" , 0 , 1 )
189+ normal_dist = pm .Normal .dist (mu , 1 )
190+ with pytest .warns (UserWarning , match = "errors when sampling on platforms" ):
191+ obs = pm .DensityDist ("density_dist" , normal_dist .logp , observed = np .random .randn (100 ))
193192
194193
195194def test_spawn_densitydist_mpctxwarning (monkeypatch ):
@@ -198,6 +197,5 @@ def test_spawn_densitydist_mpctxwarning(monkeypatch):
198197 with pm .Model () as model :
199198 mu = pm .Normal ("mu" , 0 , 1 )
200199 normal_dist = pm .Normal .dist (mu , 1 )
201- with pytest .warns (UserWarning ) as w :
200+ with pytest .warns (UserWarning , match = "errors when sampling when multiprocessing" ) :
202201 obs = pm .DensityDist ("density_dist" , normal_dist .logp , observed = np .random .randn (100 ))
203- assert len (w ) == 1 and "errors when sampling when multiprocessing" in w [0 ].message .args [0 ]
0 commit comments