Skip to content

Commit 10d6765

Browse files
authored
bugfix for fit_laplace absent dims (#609)
1 parent 3531d29 commit 10d6765

File tree

2 files changed

+70
-9
lines changed

2 files changed

+70
-9
lines changed

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import logging
17+
import re
1718

1819
from collections.abc import Callable
1920
from functools import partial
@@ -51,6 +52,58 @@
5152
_log = logging.getLogger(__name__)
5253

5354

55+
def _reset_laplace_dim_idx(idata: az.InferenceData) -> az.InferenceData:
    """
    Undo the index shift of auto-generated dimension names in ``idata.posterior``.

    Because `fit_laplace` adds the (temp_chain, temp_draw) dimensions,
    any variables without explicitly assigned dimensions receive
    automatically generated indices that are shifted by two during
    InferenceData creation.

    This helper function corrects that shift by subtracting 2 from the
    automatically detected dimension indices of the form
    `<varname>_dim_<idx>`, restoring them to the indices they would have
    had if the (temp_chain, temp_draw) dimensions were not added. A variable
    of the same name as a renamed dimension (its coordinate) is renamed too.

    Only affects auto-assigned dimensions in `idata.posterior`.

    Parameters
    ----------
    idata : az.InferenceData
        Inference data whose ``posterior`` group is inspected. The group is
        reassigned on this object when at least one dimension is renamed.

    Returns
    -------
    az.InferenceData
        The same ``idata`` object that was passed in.

    Raises
    ------
    ValueError
        If a posterior dimension matches ``<varname>_dim_<idx>`` with an
        index below 2, since subtracting 2 would give a negative index.
        Nothing is renamed in that case.
    """
    pattern = re.compile(r"^(?P<base>.+)_dim_(?P<idx>\d+)$")

    # Collect and validate every rename before touching the dataset, so that
    # a ValueError leaves ``idata.posterior`` unmodified.
    renames = []
    for dim in idata.posterior.dims:
        match = pattern.match(dim)
        if match is None:
            continue

        base = match.group("base")
        idx = int(match.group("idx"))

        # Guard against invalid or unintended renames
        if idx < 2:
            raise ValueError(
                f"Cannot reset Laplace dimension index for '{dim}': "
                f"index {idx} would become negative."
            )

        renames.append((idx, dim, f"{base}_dim_{idx - 2}"))

    if not renames:
        return idata

    posterior = idata.posterior

    # Apply the renames one at a time in ascending index order, renaming the
    # dimension and its coordinate variable together. For a variable with
    # three or more automatic dims (e.g. x_dim_2, x_dim_3, x_dim_4) the target
    # of a later rename (x_dim_4 -> x_dim_2) is the source of an earlier one;
    # a single batched `rename_dims` call refuses a target name that still
    # exists in the dataset, so each target must be vacated first.
    for _, dim, new_dim in sorted(renames, key=lambda item: item[0]):
        posterior = posterior.rename_dims({dim: new_dim})

        # Only rename variables if they actually exist
        if dim in posterior.variables:
            posterior = posterior.rename_vars({dim: new_dim})

    idata.posterior = posterior

    return idata
105+
106+
54107
def get_conditional_gaussian_approximation(
55108
x: TensorVariable,
56109
Q: TensorVariable | ArrayLike,
@@ -224,12 +277,8 @@ def model_to_laplace_approx(
224277
elif name in model.named_vars_to_dims:
225278
dims = (*batch_dims, *model.named_vars_to_dims[name])
226279
else:
227-
dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
228-
initval = initial_point.get(name, None)
229-
dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
230-
laplace_model.add_coords(
231-
{name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
232-
)
280+
n_dim = batched_rv.ndim - 2 # (temp_chain, temp_draw) are always first 2 dims
281+
dims = (*batch_dims,) + (None,) * n_dim
233282

234283
pm.Deterministic(name, batched_rv, dims=dims)
235284

@@ -468,4 +517,6 @@ def fit_laplace(
468517
["laplace_approximation", "unpacked_variable_names"]
469518
)
470519

520+
idata = _reset_laplace_dim_idx(idata)
521+
471522
return idata

tests/inference/laplace_approx/test_laplace.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,21 @@ def test_fit_laplace_coords(include_transformed, rng):
153153
assert "city" in idata.unconstrained_posterior.coords
154154

155155

156-
def test_fit_laplace_ragged_coords(rng):
156+
@pytest.mark.parametrize(
157+
"chains, draws, use_dims",
158+
[(1, 500, False), (1, 500, True), (2, 1000, False), (2, 1000, True)],
159+
)
160+
def test_fit_laplace_ragged_coords(chains, draws, use_dims, rng):
157161
coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
158162
with pm.Model(coords=coords) as ragged_dim_model:
159-
X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])
163+
X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"] if use_dims else None)
160164
beta = pm.Normal(
161165
"beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"]
162166
)
163167
mu = pm.Deterministic(
164-
"mu", (X[:, None, :] * beta[None]).sum(axis=-1), dims=["obs_idx", "city"]
168+
"mu",
169+
(X[:, None, :] * beta[None]).sum(axis=-1),
170+
dims=["obs_idx", "city"] if use_dims else None,
165171
)
166172
sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"])
167173

@@ -178,6 +184,8 @@ def test_fit_laplace_ragged_coords(rng):
178184
progressbar=False,
179185
use_grad=True,
180186
use_hessp=True,
187+
chains=chains,
188+
draws=draws,
181189
)
182190

183191
# These should have been dropped when the laplace idata was created
@@ -186,6 +194,8 @@ def test_fit_laplace_ragged_coords(rng):
186194

187195
assert idata["posterior"].beta.shape[-2:] == (3, 2)
188196
assert idata["posterior"].sigma.shape[-1:] == (3,)
197+
assert idata["posterior"].chain.shape[0] == chains
198+
assert idata["posterior"].draw.shape[0] == draws
189199

190200
# Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1
191201
# strictly positive

0 commit comments

Comments
 (0)