|
14 | 14 |
|
15 | 15 |
|
16 | 16 | import logging |
| 17 | +import re |
17 | 18 |
|
18 | 19 | from collections.abc import Callable |
19 | 20 | from functools import partial |
|
51 | 52 | _log = logging.getLogger(__name__) |
52 | 53 |
|
53 | 54 |
|
| 55 | +def _reset_laplace_dim_idx(idata: az.InferenceData) -> az.InferenceData: |
| 56 | + """ |
| 57 | + Because `fit_laplace` adds the (temp_chain, temp_draw) dimensions, |
| 58 | + any variables without explicitly assigned dimensions receive |
| 59 | + automatically generated indices that are shifted by two during |
| 60 | + InferenceData creation. |
| 61 | +
|
| 62 | + This helper function corrects that shift by subtracting 2 from the |
| 63 | + automatically detected dimension indices of the form |
| 64 | + `<varname>_dim_<idx>`, restoring them to the indices they would have |
| 65 | + had if the (temp_chain, temp_draw) dimensions were not added. |
| 66 | +
|
| 67 | + Only affects auto-assigned dimensions in `idata.posterior`. |
| 68 | + """ |
| 69 | + |
| 70 | + pattern = re.compile(r"^(?P<base>.+)_dim_(?P<idx>\d+)$") |
| 71 | + |
| 72 | + dim_renames = {} |
| 73 | + var_renames = {} |
| 74 | + |
| 75 | + for dim in idata.posterior.dims: |
| 76 | + match = pattern.match(dim) |
| 77 | + if match is None: |
| 78 | + continue |
| 79 | + |
| 80 | + base = match.group("base") |
| 81 | + idx = int(match.group("idx")) |
| 82 | + |
| 83 | + # Guard against invalid or unintended renames |
| 84 | + if idx < 2: |
| 85 | + raise ValueError( |
| 86 | + f"Cannot reset Laplace dimension index for '{dim}': " |
| 87 | + f"index {idx} would become negative." |
| 88 | + ) |
| 89 | + |
| 90 | + new_dim = f"{base}_dim_{idx - 2}" |
| 91 | + |
| 92 | + dim_renames[dim] = new_dim |
| 93 | + |
| 94 | + # Only rename variables if they actually exist |
| 95 | + if dim in idata.posterior.variables: |
| 96 | + var_renames[dim] = new_dim |
| 97 | + |
| 98 | + if dim_renames: |
| 99 | + idata.posterior = idata.posterior.rename_dims(dim_renames) |
| 100 | + |
| 101 | + if var_renames: |
| 102 | + idata.posterior = idata.posterior.rename_vars(var_renames) |
| 103 | + |
| 104 | + return idata |
| 105 | + |
| 106 | + |
54 | 107 | def get_conditional_gaussian_approximation( |
55 | 108 | x: TensorVariable, |
56 | 109 | Q: TensorVariable | ArrayLike, |
@@ -224,12 +277,8 @@ def model_to_laplace_approx( |
224 | 277 | elif name in model.named_vars_to_dims: |
225 | 278 | dims = (*batch_dims, *model.named_vars_to_dims[name]) |
226 | 279 | else: |
227 | | - dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)]) |
228 | | - initval = initial_point.get(name, None) |
229 | | - dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:] |
230 | | - laplace_model.add_coords( |
231 | | - {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)} |
232 | | - ) |
| 280 | + n_dim = batched_rv.ndim - 2 # (temp_chain, temp_draw) are always first 2 dims |
| 281 | + dims = (*batch_dims,) + (None,) * n_dim |
233 | 282 |
|
234 | 283 | pm.Deterministic(name, batched_rv, dims=dims) |
235 | 284 |
|
@@ -468,4 +517,6 @@ def fit_laplace( |
468 | 517 | ["laplace_approximation", "unpacked_variable_names"] |
469 | 518 | ) |
470 | 519 |
|
| 520 | + idata = _reset_laplace_dim_idx(idata) |
| 521 | + |
471 | 522 | return idata |
0 commit comments