Skip to content

Commit 1450629

Browse files
committed
added test for sampling dims and consolidated two tests into one
1 parent 81086e7 commit 1450629

File tree

1 file changed

+13
-41
lines changed

1 file changed

+13
-41
lines changed

tests/inference/laplace_approx/test_laplace.py

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -153,54 +153,22 @@ def test_fit_laplace_coords(include_transformed, rng):
153153
assert "city" in idata.unconstrained_posterior.coords
154154

155155

156-
def test_fit_laplace_ragged_coords(rng):
156+
@pytest.mark.parametrize(
157+
"chains, draws, use_dims",
158+
[(1, 500, False), (1, 500, True), (2, 1000, False), (2, 1000, True)],
159+
)
160+
def test_fit_laplace_ragged_coords(chains, draws, use_dims, rng):
157161
coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
158162
with pm.Model(coords=coords) as ragged_dim_model:
159-
X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])
163+
X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"] if use_dims else None)
160164
beta = pm.Normal(
161165
"beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"]
162166
)
163167
mu = pm.Deterministic(
164-
"mu", (X[:, None, :] * beta[None]).sum(axis=-1), dims=["obs_idx", "city"]
165-
)
166-
sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"])
167-
168-
obs = pm.Normal(
169-
"obs",
170-
mu=mu,
171-
sigma=sigma,
172-
observed=rng.normal(loc=3, scale=1.5, size=(100, 3)),
173-
dims=["obs_idx", "city"],
174-
)
175-
176-
idata = fit_laplace(
177-
optimize_method="Newton-CG",
178-
progressbar=False,
179-
use_grad=True,
180-
use_hessp=True,
181-
)
182-
183-
# These should have been dropped when the laplace idata was created
184-
assert "laplace_approximation" not in list(idata.posterior.data_vars.keys())
185-
assert "unpacked_var_names" not in list(idata.posterior.coords.keys())
186-
187-
assert idata["posterior"].beta.shape[-2:] == (3, 2)
188-
assert idata["posterior"].sigma.shape[-1:] == (3,)
189-
190-
# Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1
191-
# strictly positive
192-
assert (idata["posterior"].beta.sel(feature=0).to_numpy() < 0).all()
193-
assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all()
194-
195-
196-
def test_fit_laplace_no_data_or_deterministic_dims(rng):
197-
coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
198-
with pm.Model(coords=coords) as ragged_dim_model:
199-
X = pm.Data("X", np.ones((100, 2)))
200-
beta = pm.Normal(
201-
"beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"]
168+
"mu",
169+
(X[:, None, :] * beta[None]).sum(axis=-1),
170+
dims=["obs_idx", "city"] if use_dims else None,
202171
)
203-
mu = pm.Deterministic("mu", (X[:, None, :] * beta[None]).sum(axis=-1))
204172
sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"])
205173

206174
obs = pm.Normal(
@@ -216,6 +184,8 @@ def test_fit_laplace_no_data_or_deterministic_dims(rng):
216184
progressbar=False,
217185
use_grad=True,
218186
use_hessp=True,
187+
chains=chains,
188+
draws=draws,
219189
)
220190

221191
# These should have been dropped when the laplace idata was created
@@ -224,6 +194,8 @@ def test_fit_laplace_no_data_or_deterministic_dims(rng):
224194

225195
assert idata["posterior"].beta.shape[-2:] == (3, 2)
226196
assert idata["posterior"].sigma.shape[-1:] == (3,)
197+
assert idata["posterior"].chain.shape[0] == chains
198+
assert idata["posterior"].draw.shape[0] == draws
227199

228200
# Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1
229201
# strictly positive

0 commit comments

Comments
 (0)