
Commit bfabe0d

Merge branch 'main' into implement-pmx.fit-option-for-INLA-+-marginalisation-routine
2 parents 0ebfaf8 + 3531d29 commit bfabe0d

File tree

33 files changed: +222 additions, -453 deletions

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ jobs:
           - tests/statespace/filters/test_kalman_filter.py
           - tests/statespace --ignore tests/statespace/core/test_statespace.py --ignore tests/statespace/filters/test_kalman_filter.py
           - tests/distributions
-          - tests --ignore tests/model --ignore tests/statespace --ignore tests/distributions
+          - tests --ignore tests/model --ignore tests/statespace --ignore tests/distributions --ignore tests/pathfinder
         fail-fast: false
     runs-on: ${{ matrix.os }}
     env:

conda-envs/environment-test.yml

Lines changed: 3 additions & 3 deletions
@@ -1,10 +1,10 @@
-name: pymc-extras-test
+name: pymc-extras
 channels:
   - conda-forge
   - nodefaults
 dependencies:
-  - pymc>=5.24.1
-  - pytensor>=2.31.4
+  - pymc>=5.26.1
+  - pytensor>=2.35.1
   - scikit-learn
   - better-optimize>=0.1.5
   - dask<2025.1.1

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 14 additions & 9 deletions
@@ -168,6 +168,7 @@ def find_MAP(
     jitter_rvs: list[TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
+    freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
     compile_kwargs: dict | None = None,
     compute_hessian: bool = False,
@@ -210,6 +211,10 @@ def find_MAP(
         Whether to display a progress bar during optimization. Defaults to True.
     include_transformed: bool, optional
         Whether to include transformed variable values in the returned dictionary. Defaults to True.
+    freeze_model: bool, optional
+        If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
+        sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
+        True.
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
     compute_hessian: bool
@@ -229,11 +234,13 @@ def find_MAP(
         Results of Maximum A Posteriori (MAP) estimation, including the optimized point, inverse Hessian, transformed
         latent variables, and optimizer results.
     """
-    model = pm.modelcontext(model) if model is None else model
-    frozen_model = freeze_dims_and_data(model)
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+    model = pm.modelcontext(model) if model is None else model

-    initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
+    if freeze_model:
+        model = freeze_dims_and_data(model)
+
+    initial_params = _make_initial_point(model, initvals, random_seed, jitter_rvs)

     do_basinhopping = method == "basinhopping"
     minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
@@ -251,8 +258,8 @@ def find_MAP(
     )

     f_fused, f_hessp = scipy_optimize_funcs_from_loss(
-        loss=-frozen_model.logp(),
-        inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
+        loss=-model.logp(),
+        inputs=model.continuous_value_vars + model.discrete_value_vars,
         initial_point_dict=DictToArrayBijection.rmap(initial_params),
         use_grad=use_grad,
         use_hess=use_hess,
@@ -316,12 +323,10 @@ def find_MAP(
     }

     idata = map_results_to_inference_data(
-        map_point=optimized_point, model=frozen_model, include_transformed=include_transformed
+        map_point=optimized_point, model=model, include_transformed=include_transformed
     )

-    idata = add_fit_to_inference_data(
-        idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
-    )
+    idata = add_fit_to_inference_data(idata=idata, mu=raveled_optimized, H_inv=H_inv, model=model)

     idata = add_optimizer_result_to_inference_data(
         idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model
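
The new freeze_model flag lets callers decide whether find_MAP freezes dims and data itself before compiling the loss. Below is a minimal usage sketch, not taken from the repository: the toy model and data are invented for illustration, and only find_MAP's keyword arguments (freeze_model, progressbar, method) come from the diff above.

import numpy as np
import pymc as pm

from pymc_extras.inference.laplace_approx.find_map import find_MAP

rng = np.random.default_rng(0)
y_obs = rng.normal(loc=1.0, scale=2.0, size=100)

with pm.Model():
    mu = pm.Normal("mu", 0, 10)
    sigma = pm.HalfNormal("sigma", 5)
    pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)

    # freeze_model=True (the default) calls freeze_dims_and_data on the model before
    # compiling the loss, which the docstring notes is sometimes required for JAX and
    # can enable constant folding; pass False if the model is already frozen upstream.
    idata = find_MAP(method="L-BFGS-B", freeze_model=True, progressbar=False)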

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 22 additions & 5 deletions
@@ -67,9 +67,13 @@ def _unconstrained_vector_to_constrained_rvs(model):
     unconstrained_vector.name = "unconstrained_vector"

     # Redo the names list to ensure it is sorted to match the return order
-    names = [*constrained_names, *unconstrained_names]
+    constrained_rvs_and_names = [(rv, name) for rv, name in zip(constrained_rvs, constrained_names)]
+    value_rvs_and_names = [
+        (rv, name) for rv, name in zip(value_rvs, names) for name in unconstrained_names
+    ]
+    # names = [*constrained_names, *unconstrained_names]

-    return names, constrained_rvs, value_rvs, unconstrained_vector
+    return constrained_rvs_and_names, value_rvs_and_names, unconstrained_vector


 def model_to_laplace_approx(
@@ -81,8 +85,11 @@ def model_to_laplace_approx(

     # temp_chain and temp_draw are a hack to allow sampling from the Laplace approximation. We only have one mu and cov,
     # so we add batch dims (which correspond to chains and draws). But the names "chain" and "draw" are reserved.
-    names, constrained_rvs, value_rvs, unconstrained_vector = (
-        _unconstrained_vector_to_constrained_rvs(model)
+
+    # The model was frozen during the find_MAP procedure. To ensure we're operating on the same model, freeze it again.
+    frozen_model = freeze_dims_and_data(model)
+    constrained_rvs_and_names, _, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(
+        frozen_model
     )

     coords = model.coords | {
@@ -103,12 +110,13 @@ def model_to_laplace_approx(
     )

     cast_to_var = partial(type_cast, Variable)
+    constrained_rvs, constrained_names = zip(*constrained_rvs_and_names)
     batched_rvs = vectorize_graph(
         type_cast(list[Variable], constrained_rvs),
         replace={cast_to_var(unconstrained_vector): cast_to_var(laplace_approximation)},
     )

-    for name, batched_rv in zip(names, batched_rvs):
+    for name, batched_rv in zip(constrained_names, batched_rvs):
         batch_dims = ("temp_chain", "temp_draw")
         if batched_rv.ndim == 2:
             dims = batch_dims
@@ -184,6 +192,7 @@ def fit_laplace(
     jitter_rvs: list[pt.TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
+    freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
     chains: int = 2,
     draws: int = 500,
@@ -227,6 +236,10 @@ def fit_laplace(
     include_transformed: bool, default True
         Whether to include transformed variables in the output. If True, transformed variables will be included in the
         output InferenceData object. If False, only the original variables will be included.
+    freeze_model: bool, optional
+        If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
+        sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
+        True.
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
@@ -275,6 +288,9 @@ def fit_laplace(
     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
     model = pm.modelcontext(model) if model is None else model

+    if freeze_model:
+        model = freeze_dims_and_data(model)
+
     idata = find_MAP(
         method=optimize_method,
         model=model,
@@ -286,6 +302,7 @@ def fit_laplace(
         jitter_rvs=jitter_rvs,
         progressbar=progressbar,
         include_transformed=include_transformed,
+        freeze_model=False,
         gradient_backend=gradient_backend,
         compile_kwargs=compile_kwargs,
         compute_hessian=True,
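
With this change fit_laplace freezes the model once itself (when freeze_model=True) and forwards freeze_model=False to the inner find_MAP call, so freezing is not applied twice. A hedged sketch of that call path, using an invented toy model; only the fit_laplace keyword arguments shown here appear in the diff.

import numpy as np
import pymc as pm

from pymc_extras.inference.laplace_approx.laplace import fit_laplace

y_obs = np.random.default_rng(1).normal(size=50)

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu=mu, sigma=1, observed=y_obs)

    # fit_laplace freezes dims and data here (freeze_model=True by default) and passes
    # freeze_model=False down to find_MAP, so the model is frozen exactly once.
    idata = fit_laplace(chains=2, draws=500, freeze_model=True, progressbar=False)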

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 2 additions & 5 deletions
@@ -22,7 +22,7 @@
 from collections.abc import Callable, Iterator
 from dataclasses import asdict, dataclass, field, replace
 from enum import Enum, auto
-from typing import Literal, TypeAlias
+from typing import Literal, Self, TypeAlias

 import arviz as az
 import filelock
@@ -60,9 +60,6 @@
 from rich.table import Table
 from rich.text import Text

-# TODO: change to typing.Self after Python versions greater than 3.10
-from typing_extensions import Self
-
 from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data
 from pymc_extras.inference.pathfinder.importance_sampling import (
     importance_sampling as _importance_sampling,
@@ -533,7 +530,7 @@ def bfgs_sample_sparse(

     # qr_input: (L, N, 2J)
     qr_input = inv_sqrt_alpha_diag @ beta
-    (Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False)
+    (Q, R), _ = pytensor.scan(fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False)

     IdN = pt.eye(R.shape[1])[None, ...]
     IdN += IdN * REGULARISATION_TERM
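
The scan now maps pt.linalg.qr (instead of pt.nlinalg.qr) over the leading batch dimension of qr_input. A standalone sketch of that batched-QR pattern under the pinned pytensor version; the tensor name and shapes below are made up for illustration.

import numpy as np
import pytensor
import pytensor.tensor as pt

# Batched QR: apply pt.linalg.qr to each (N, 2J) slice of a (L, N, 2J) tensor.
qr_input = pt.tensor3("qr_input")
(Q, R), _ = pytensor.scan(fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False)

f = pytensor.function([qr_input], [Q, R])
Q_val, R_val = f(np.random.default_rng(2).normal(size=(4, 5, 3)))
print(Q_val.shape, R_val.shape)  # (4, 5, 3) (4, 3, 3)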

pymc_extras/model/marginal/graph_analysis.py

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,8 @@
 from pymc import SymbolicRandomVariable
 from pymc.model.fgraph import ModelVar
 from pymc.variational.minibatch_rv import MinibatchRandomVariable
-from pytensor.graph import Variable, ancestors
-from pytensor.graph.traversal import io_toposort
+from pytensor.graph.basic import Variable
+from pytensor.graph.traversal import ancestors, io_toposort
 from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise

pymc_extras/model_builder.py

Lines changed: 3 additions & 0 deletions
@@ -446,6 +446,7 @@ def load(cls, fname: str):
             sampler_config=json.loads(idata.attrs["sampler_config"]),
         )
         model.idata = idata
+        model.is_fitted_ = True
         dataset = idata.fit_data.to_dataframe()
         X = dataset.drop(columns=[model.output_var])
         y = dataset[model.output_var]
@@ -526,6 +527,8 @@ def fit(
         )
         self.idata.add_groups(fit_data=combined_data.to_xarray())  # type: ignore

+        self.is_fitted_ = True
+
         return self.idata  # type: ignore

     def predict(
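
Both fit() and load() now set is_fitted_, so a model restored from disk is marked as fitted without retraining. A rough round-trip sketch under that assumption; MyModel is a hypothetical ModelBuilder subclass and X_train, y_train, X_new are placeholder user data, none of which appear in the diff.

# MyModel is a hypothetical ModelBuilder subclass; only fit/save/load/predict and
# the is_fitted_ attribute are taken from the class being patched here.
model = MyModel()
model.fit(X_train, y_train)            # fit() now sets self.is_fitted_ = True
model.save("my_model.nc")

restored = MyModel.load("my_model.nc")  # load() now also sets is_fitted_ = True
assert restored.is_fitted_
preds = restored.predict(X_new)         # the restored instance behaves as fitted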

pymc_extras/statespace/core/compile.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def compile_statespace(
         x0, P0, c, d, T, Z, R, H, Q, steps=steps, sequence_names=sequence_names
     )

-    inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
+    inputs = list(pytensor.graph.traversal.explicit_graph_inputs(outputs))

     _f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
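
explicit_graph_inputs now lives in pytensor.graph.traversal rather than pytensor.graph.basic in the pinned pytensor version. A small sketch of what it does, using an invented two-variable graph rather than the statespace outputs.

import pytensor.tensor as pt
from pytensor.graph.traversal import explicit_graph_inputs

x = pt.dscalar("x")
y = pt.dscalar("y")
out = x * y + 2.0

# Collect the explicit (non-constant) input variables feeding the output graph.
inputs = list(explicit_graph_inputs([out]))
print([inp.name for inp in inputs])  # ['x', 'y'] (order may vary)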

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 2 additions & 2 deletions
@@ -200,7 +200,7 @@ def build_graph(
         self.n_endog = Z_shape[-2]

         data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
-
+        data = pt.specify_shape(data, (data.type.shape[0], self.n_endog))
         sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
             params, PARAM_NAMES
         )
@@ -658,7 +658,7 @@ def update(self, a, P, y, d, Z, H, all_nan_flag):
         # Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
         #                                              [0,       L_pred]]
         # The Schur decomposition of this matrix will be B (upper triangular). We are
-        # more insterested in B^T:
+        # more interested in B^T:
         # Structure of B^T = [[chol(F),     0],
         #                     [K @ chol(F), chol(P_filtered)]
         zeros = pt.zeros((self.n_states, self.n_endog))
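
The added pt.specify_shape call pins the trailing dimension of data to n_endog while leaving the number of time steps unknown, giving the filter graph a static column count. A minimal sketch of that primitive; the variable names here are illustrative, not from the filter code.

import pytensor.tensor as pt

n_endog = 3
data = pt.matrix("data")  # static type shape (None, None)

# Keep the (unknown) number of time steps, but assert exactly n_endog columns.
data = pt.specify_shape(data, (data.type.shape[0], n_endog))
print(data.type.shape)  # (None, 3)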

pymc_extras/statespace/filters/kalman_smoother.py

Lines changed: 1 addition & 3 deletions
@@ -1,8 +1,6 @@
 import pytensor
 import pytensor.tensor as pt

-from pytensor.tensor.nlinalg import matrix_dot
-
 from pymc_extras.statespace.filters.utilities import (
     quad_form_sym,
     split_vars_into_seq_and_nonseq,
@@ -105,7 +103,7 @@ def smoother_step(self, *args):
         a_hat, P_hat = self.predict(a, P, T, R, Q)

         # Use pinv, otherwise P_hat is singular when there is missing data
-        smoother_gain = matrix_dot(pt.linalg.pinv(P_hat, hermitian=True), T, P).T
+        smoother_gain = (pt.linalg.pinv(P_hat, hermitian=True) @ T @ P).mT
         a_smooth_next = a + smoother_gain @ (a_smooth - a_hat)

         P_smooth_next = P + quad_form_sym(smoother_gain, P_smooth - P_hat)
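
matrix_dot(A, B, C) is replaced by chained @ products, and .T by .mT, which transposes only the last two axes and so stays correct if the operands ever carry batch dimensions. A small equivalence sketch under the pinned pytensor version; the random matrices below are placeholders, not smoother quantities.

import numpy as np
import pytensor
import pytensor.tensor as pt

P_hat, T, P = (pt.matrix(name) for name in ("P_hat", "T", "P"))

# matrix_dot(A, B, C).T is just (A @ B @ C) transposed; .mT swaps the last two axes.
smoother_gain = (pt.linalg.pinv(P_hat, hermitian=True) @ T @ P).mT

f = pytensor.function([P_hat, T, P], smoother_gain)
rng = np.random.default_rng(3)
A = rng.normal(size=(4, 4))
out = f(A @ A.T + np.eye(4), rng.normal(size=(4, 4)), rng.normal(size=(4, 4)))
print(out.shape)  # (4, 4)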
