Skip to content

Commit 0dfb495

Browse files
committed
Swap arguments in recover_marginals
Closes #610
1 parent 10d6765 commit 0dfb495

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

pymc_extras/model/marginal/marginal_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
1212
from pymc.distributions.transforms import Chain
1313
from pymc.logprob.transforms import IntervalTransform
14-
from pymc.model import Model
14+
from pymc.model import Model, modelcontext
1515
from pymc.model.fgraph import (
1616
ModelFreeRV,
1717
ModelValuedVar,
@@ -337,8 +337,8 @@ def transform_posterior_pts(model, posterior_pts):
337337

338338

339339
def recover_marginals(
340-
model: Model,
341340
idata: InferenceData,
341+
model: Model | None = None,
342342
var_names: Sequence[str] | None = None,
343343
return_samples: bool = True,
344344
extend_inferencedata: bool = True,
@@ -389,6 +389,11 @@ def recover_marginals(
389389
390390
391391
"""
392+
if isinstance(idata, Model):
393+
raise TypeError("The first argument of `recover_marginals` must be an idata")
394+
395+
model = modelcontext(model)
396+
392397
unmarginal_model = unmarginalize(model)
393398

394399
# Find the names of the marginalized variables

tests/model/marginal/test_marginal_model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,9 @@ def test_basic(self):
837837
)
838838
idata = InferenceData(posterior=dict_to_dataset(prior))
839839

840-
idata = recover_marginals(marginal_m, idata, return_samples=True)
840+
with marginal_m:
841+
idata = recover_marginals(idata, return_samples=True)
842+
841843
post = idata.posterior
842844
assert "k" in post
843845
assert "lp_k" in post
@@ -881,7 +883,8 @@ def test_coords(self):
881883
posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
882884
)
883885

884-
idata = recover_marginals(marginal_m, idata, return_samples=True)
886+
with marginal_m:
887+
idata = recover_marginals(idata, return_samples=True)
885888
post = idata.posterior
886889
assert post.idx.dims == ("chain", "draw", "year")
887890
assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim")
@@ -907,7 +910,7 @@ def test_batched(self):
907910
posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
908911
)
909912

910-
idata = recover_marginals(marginal_m, idata, return_samples=True)
913+
idata = recover_marginals(idata, return_samples=True)
911914
post = idata.posterior
912915
assert post["y"].shape == (1, 20, 2, 3)
913916
assert post["idx"].shape == (1, 20, 3, 2)
@@ -933,7 +936,7 @@ def test_nested(self):
933936
)
934937
idata = InferenceData(posterior=dict_to_dataset(prior))
935938

936-
idata = recover_marginals(marginal_m, idata, return_samples=True)
939+
idata = recover_marginals(idata, return_samples=True)
937940
post = idata.posterior
938941
assert "idx" in post
939942
assert "lp_idx" in post

0 commit comments

Comments
 (0)