1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import warnings
16+
1517from contextlib import ExitStack as does_not_raise
1618
1719import aesara
@@ -655,10 +657,10 @@ def mixmixlogp(value, point):
655657 assert_allclose (priorlogp + mixmixlogpg .sum (), model .logp (test_point ), rtol = rtol )
656658
657659 def test_iterable_single_component_warning (self ):
658- with pytest .warns (None ) as record :
660+ with warnings .catch_warnings ():
661+ warnings .simplefilter ("error" )
659662 Mixture .dist (w = [0.5 , 0.5 ], comp_dists = Normal .dist (size = 2 ))
660663 Mixture .dist (w = [0.5 , 0.5 ], comp_dists = [Normal .dist (size = 2 ), Normal .dist (size = 2 )])
661- assert not record
662664
663665 with pytest .warns (UserWarning , match = "Single component will be treated as a mixture" ):
664666 Mixture .dist (w = [0.5 , 0.5 ], comp_dists = [Normal .dist (size = 2 )])
@@ -1303,9 +1305,9 @@ def test_logp(self):
13031305 def test_warning (self ):
13041306 with Model () as m :
13051307 comp_dists = [HalfNormal .dist (), Exponential .dist (1 )]
1306- with pytest .warns (None ) as rec :
1308+ with warnings .catch_warnings ():
1309+ warnings .simplefilter ("error" )
13071310 Mixture ("mix1" , w = [0.5 , 0.5 ], comp_dists = comp_dists )
1308- assert not rec
13091311
13101312 comp_dists = [Uniform .dist (0 , 1 ), Uniform .dist (0 , 2 )]
13111313 with pytest .warns (MixtureTransformWarning ):
@@ -1315,16 +1317,16 @@ def test_warning(self):
13151317 with pytest .warns (MixtureTransformWarning ):
13161318 Mixture ("mix3" , w = [0.5 , 0.5 ], comp_dists = comp_dists )
13171319
1318- with pytest .warns (None ) as rec :
1320+ with warnings .catch_warnings ():
1321+ warnings .simplefilter ("error" )
13191322 Mixture ("mix4" , w = [0.5 , 0.5 ], comp_dists = comp_dists , transform = None )
1320- assert not rec
13211323
1322- with pytest .warns (None ) as rec :
1324+ with warnings .catch_warnings ():
1325+ warnings .simplefilter ("error" )
13231326 Mixture ("mix5" , w = [0.5 , 0.5 ], comp_dists = comp_dists , observed = 1 )
1324- assert not rec
13251327
13261328 # Case where the appropriate default transform is None
13271329 comp_dists = [Normal .dist (), Normal .dist ()]
1328- with pytest .warns (None ) as rec :
1330+ with warnings .catch_warnings ():
1331+ warnings .simplefilter ("error" )
13291332 Mixture ("mix6" , w = [0.5 , 0.5 ], comp_dists = comp_dists )
1330- assert not rec
0 commit comments