Skip to content

Commit 81086e7

Browse files
committed
allow inference data object to automatically set dims and post process the dimension names after
1 parent 5abe4ff commit 81086e7

File tree

1 file changed

+57
-9
lines changed

1 file changed

+57
-9
lines changed

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import logging
17+
import re
1718

1819
from collections.abc import Callable
1920
from functools import partial
@@ -51,6 +52,58 @@
5152
_log = logging.getLogger(__name__)
5253

5354

55+
def _reset_laplace_dim_idx(idata: az.InferenceData) -> az.InferenceData:
56+
"""
57+
Because `fit_laplace` adds the (temp_chain, temp_draw) dimensions,
58+
any variables without explicitly assigned dimensions receive
59+
automatically generated indices that are shifted by two during
60+
InferenceData creation.
61+
62+
This helper function corrects that shift by subtracting 2 from the
63+
automatically detected dimension indices of the form
64+
`<varname>_dim_<idx>`, restoring them to the indices they would have
65+
had if the (temp_chain, temp_draw) dimensions were not added.
66+
67+
Only affects auto-assigned dimensions in `idata.posterior`.
68+
"""
69+
70+
pattern = re.compile(r"^(?P<base>.+)_dim_(?P<idx>\d+)$")
71+
72+
dim_renames = {}
73+
var_renames = {}
74+
75+
for dim in idata.posterior.dims:
76+
match = pattern.match(dim)
77+
if match is None:
78+
continue
79+
80+
base = match.group("base")
81+
idx = int(match.group("idx"))
82+
83+
# Guard against invalid or unintended renames
84+
if idx < 2:
85+
raise ValueError(
86+
f"Cannot reset Laplace dimension index for '{dim}': "
87+
f"index {idx} would become negative."
88+
)
89+
90+
new_dim = f"{base}_dim_{idx - 2}"
91+
92+
dim_renames[dim] = new_dim
93+
94+
# Only rename variables if they actually exist
95+
if dim in idata.posterior.variables:
96+
var_renames[dim] = new_dim
97+
98+
if dim_renames:
99+
idata.posterior = idata.posterior.rename_dims(dim_renames)
100+
101+
if var_renames:
102+
idata.posterior = idata.posterior.rename_vars(var_renames)
103+
104+
return idata
105+
106+
54107
def get_conditional_gaussian_approximation(
55108
x: TensorVariable,
56109
Q: TensorVariable | ArrayLike,
@@ -224,15 +277,8 @@ def model_to_laplace_approx(
224277
elif name in model.named_vars_to_dims:
225278
dims = (*batch_dims, *model.named_vars_to_dims[name])
226279
else:
227-
initval = initial_point.get(name, None)
228-
dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
229-
if dim_shapes[0] is not None:
230-
dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
231-
laplace_model.add_coords(
232-
{name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
233-
)
234-
else:
235-
dims = None
280+
n_dim = batched_rv.ndim - 2 # (temp_chain, temp_draw) are always first 2 dims
281+
dims = (*batch_dims,) + (None,) * n_dim
236282

237283
pm.Deterministic(name, batched_rv, dims=dims)
238284

@@ -471,4 +517,6 @@ def fit_laplace(
471517
["laplace_approximation", "unpacked_variable_names"]
472518
)
473519

520+
idata = _reset_laplace_dim_idx(idata)
521+
474522
return idata

0 commit comments

Comments
 (0)