1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import logging
15+ import platform
1516import warnings
1617
1718import numpy as np
2829from pymc .smc .kernels import IMH , systematic_resampling
2930from tests .helpers import assert_random_state_equal
3031
32+ _IS_WINDOWS = platform .system () == "Windows"
33+
3134
3235class TestSMC :
3336 """Tests for the default SMC kernel"""
@@ -75,7 +78,9 @@ def two_gaussians(x):
7578 def test_sample (self ):
7679 initial_rng_state = np .random .get_state ()
7780 with self .SMC_test :
78- mtrace = pm .sample_smc (draws = self .samples , return_inferencedata = False )
81+ mtrace = pm .sample_smc (
82+ draws = self .samples , return_inferencedata = False , progressbar = not _IS_WINDOWS
83+ )
7984
8085 # Verify sampling was done with a non-global random generator
8186 assert_random_state_equal (initial_rng_state , np .random .get_state ())
@@ -142,7 +147,9 @@ def test_marginal_likelihood(self):
142147 with pm .Model () as model :
143148 a = pm .Beta ("a" , alpha , beta )
144149 y = pm .Bernoulli ("y" , a , observed = data )
145- trace = pm .sample_smc (2000 , chains = 2 , return_inferencedata = False )
150+ trace = pm .sample_smc (
151+ 2000 , chains = 2 , return_inferencedata = False , progressbar = not _IS_WINDOWS
152+ )
146153 # log_marginal_likelihood is found in the last value of each chain
147154 lml = np .mean ([chain [- 1 ] for chain in trace .report .log_marginal_likelihood ])
148155 marginals .append (lml )
@@ -203,8 +210,15 @@ def test_return_datatype(self, chains):
203210 with warnings .catch_warnings ():
204211 warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
205212 warnings .filterwarnings ("ignore" , "More chains .* than draws .*" , UserWarning )
206- idata = pm .sample_smc (chains = chains , draws = draws )
207- mt = pm .sample_smc (chains = chains , draws = draws , return_inferencedata = False )
213+ idata = pm .sample_smc (
214+ chains = chains , draws = draws , progressbar = not (chains > 1 and _IS_WINDOWS )
215+ )
216+ mt = pm .sample_smc (
217+ chains = chains ,
218+ draws = draws ,
219+ return_inferencedata = False ,
220+ progressbar = not (chains > 1 and _IS_WINDOWS ),
221+ )
208222
209223 assert isinstance (idata , InferenceData )
210224 assert "sample_stats" in idata
@@ -218,7 +232,7 @@ def test_return_datatype(self, chains):
218232 def test_convergence_checks (self , caplog ):
219233 with caplog .at_level (logging .INFO ):
220234 with self .fast_model :
221- pm .sample_smc (draws = 99 )
235+ pm .sample_smc (draws = 99 , progressbar = not _IS_WINDOWS )
222236 assert "The number of samples is too small" in caplog .text
223237
224238 def test_deprecated_parallel_arg (self ):
@@ -265,7 +279,7 @@ def test_normal_model(self):
265279 mu = pm .Normal ("mu" , 0 , 3 )
266280 sigma = pm .HalfNormal ("sigma" , 1 )
267281 y = pm .Normal ("y" , mu , sigma , observed = data )
268- idata = pm .sample_smc (draws = 2000 , kernel = pm .smc .MH )
282+ idata = pm .sample_smc (draws = 2000 , kernel = pm .smc .MH , progressbar = not _IS_WINDOWS )
269283 assert_random_state_equal (initial_rng_state , np .random .get_state ())
270284
271285 post = idata .posterior .stack (sample = ("chain" , "draw" ))
0 commit comments