
Commit ca005f3

Merge upstream/main into main
2 parents 29b6ec0 + 3531d29 commit ca005f3

36 files changed, +227 -461 lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ jobs:
           - tests/statespace/filters/test_kalman_filter.py
           - tests/statespace --ignore tests/statespace/core/test_statespace.py --ignore tests/statespace/filters/test_kalman_filter.py
           - tests/distributions
-          - tests --ignore tests/model --ignore tests/statespace --ignore tests/distributions
+          - tests --ignore tests/model --ignore tests/statespace --ignore tests/distributions --ignore tests/pathfinder
       fail-fast: false
     runs-on: ${{ matrix.os }}
     env:

conda-envs/environment-test.yml

Lines changed: 3 additions & 3 deletions
@@ -1,10 +1,10 @@
-name: pymc-extras-test
+name: pymc-extras
 channels:
   - conda-forge
   - nodefaults
 dependencies:
-  - pymc>=5.24.1
-  - pytensor>=2.31.4
+  - pymc>=5.26.1
+  - pytensor>=2.35.1
   - scikit-learn
   - better-optimize>=0.1.5
   - dask<2025.1.1

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 14 additions & 9 deletions
@@ -168,6 +168,7 @@ def find_MAP(
     jitter_rvs: list[TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
+    freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
     compile_kwargs: dict | None = None,
     compute_hessian: bool = False,
@@ -210,6 +211,10 @@ def find_MAP(
         Whether to display a progress bar during optimization. Defaults to True.
     include_transformed: bool, optional
         Whether to include transformed variable values in the returned dictionary. Defaults to True.
+    freeze_model: bool, optional
+        If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
+        sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
+        True.
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
     compute_hessian: bool
@@ -229,11 +234,13 @@ def find_MAP(
     Results of Maximum A Posteriori (MAP) estimation, including the optimized point, inverse Hessian, transformed
     latent variables, and optimizer results.
     """
-    model = pm.modelcontext(model) if model is None else model
-    frozen_model = freeze_dims_and_data(model)
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+    model = pm.modelcontext(model) if model is None else model

-    initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
+    if freeze_model:
+        model = freeze_dims_and_data(model)
+
+    initial_params = _make_initial_point(model, initvals, random_seed, jitter_rvs)

     do_basinhopping = method == "basinhopping"
     minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
@@ -251,8 +258,8 @@
     )

     f_fused, f_hessp = scipy_optimize_funcs_from_loss(
-        loss=-frozen_model.logp(),
-        inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
+        loss=-model.logp(),
+        inputs=model.continuous_value_vars + model.discrete_value_vars,
         initial_point_dict=DictToArrayBijection.rmap(initial_params),
         use_grad=use_grad,
         use_hess=use_hess,
@@ -316,12 +323,10 @@ def find_MAP(
     }

     idata = map_results_to_inference_data(
-        map_point=optimized_point, model=frozen_model, include_transformed=include_transformed
+        map_point=optimized_point, model=model, include_transformed=include_transformed
     )

-    idata = add_fit_to_inference_data(
-        idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
-    )
+    idata = add_fit_to_inference_data(idata=idata, mu=raveled_optimized, H_inv=H_inv, model=model)

     idata = add_optimizer_result_to_inference_data(
         idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model
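
Note on the change above: freezing is now opt-out. The new freeze_model flag gates the call to freeze_dims_and_data instead of freezing unconditionally, and the function works on a single model reference afterwards. A minimal usage sketch (the toy model and optimizer choice are illustrative only; find_MAP is assumed to be re-exported at the package level, as in current pymc-extras):

    import pymc as pm
    import pymc_extras as pmx

    with pm.Model(coords={"obs": range(3)}) as model:
        mu = pm.Normal("mu")
        pm.Normal("y", mu=mu, sigma=1.0, observed=[0.1, -0.2, 0.3], dims="obs")

        # freeze_model=True (the default) freezes dims and data before compiling the
        # loss, which JAX sometimes requires and which can enable constant folding.
        # Pass freeze_model=False if you plan to change shared data afterwards.
        idata = pmx.find_MAP(method="L-BFGS-B", freeze_model=True)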

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 22 additions & 5 deletions
@@ -168,9 +168,13 @@ def _unconstrained_vector_to_constrained_rvs(model):
     unconstrained_vector.name = "unconstrained_vector"

     # Redo the names list to ensure it is sorted to match the return order
-    names = [*constrained_names, *unconstrained_names]
+    constrained_rvs_and_names = [(rv, name) for rv, name in zip(constrained_rvs, constrained_names)]
+    value_rvs_and_names = [
+        (rv, name) for rv, name in zip(value_rvs, names) for name in unconstrained_names
+    ]
+    # names = [*constrained_names, *unconstrained_names]

-    return names, constrained_rvs, value_rvs, unconstrained_vector
+    return constrained_rvs_and_names, value_rvs_and_names, unconstrained_vector


 def model_to_laplace_approx(
@@ -182,8 +186,11 @@ def model_to_laplace_approx(

     # temp_chain and temp_draw are a hack to allow sampling from the Laplace approximation. We only have one mu and cov,
     # so we add batch dims (which correspond to chains and draws). But the names "chain" and "draw" are reserved.
-    names, constrained_rvs, value_rvs, unconstrained_vector = (
-        _unconstrained_vector_to_constrained_rvs(model)
+
+    # The model was frozen during the find_MAP procedure. To ensure we're operating on the same model, freeze it again.
+    frozen_model = freeze_dims_and_data(model)
+    constrained_rvs_and_names, _, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(
+        frozen_model
     )

     coords = model.coords | {
@@ -204,12 +211,13 @@
     )

     cast_to_var = partial(type_cast, Variable)
+    constrained_rvs, constrained_names = zip(*constrained_rvs_and_names)
     batched_rvs = vectorize_graph(
         type_cast(list[Variable], constrained_rvs),
         replace={cast_to_var(unconstrained_vector): cast_to_var(laplace_approximation)},
     )

-    for name, batched_rv in zip(names, batched_rvs):
+    for name, batched_rv in zip(constrained_names, batched_rvs):
         batch_dims = ("temp_chain", "temp_draw")
         if batched_rv.ndim == 2:
             dims = batch_dims
@@ -285,6 +293,7 @@ def fit_laplace(
     jitter_rvs: list[pt.TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
+    freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
     chains: int = 2,
     draws: int = 500,
@@ -328,6 +337,10 @@ def fit_laplace(
     include_transformed: bool, default True
         Whether to include transformed variables in the output. If True, transformed variables will be included in the
         output InferenceData object. If False, only the original variables will be included.
+    freeze_model: bool, optional
+        If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
+        sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
+        True.
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
@@ -376,6 +389,9 @@ def fit_laplace(
     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
     model = pm.modelcontext(model) if model is None else model

+    if freeze_model:
+        model = freeze_dims_and_data(model)
+
     idata = find_MAP(
         method=optimize_method,
         model=model,
@@ -387,6 +403,7 @@ def fit_laplace(
         jitter_rvs=jitter_rvs,
         progressbar=progressbar,
         include_transformed=include_transformed,
+        freeze_model=False,
         gradient_backend=gradient_backend,
         compile_kwargs=compile_kwargs,
         compute_hessian=True,
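
The same flag is threaded through fit_laplace: the model is frozen once up front (when freeze_model=True) and freeze_model=False is forwarded to the inner find_MAP call so the model is not frozen twice. A hedged sketch of the resulting call (toy model as above; fit_laplace is assumed to be re-exported at the package level):

    import pymc as pm
    import pymc_extras as pmx

    with pm.Model() as model:
        mu = pm.Normal("mu")
        pm.Normal("y", mu=mu, sigma=1.0, observed=[0.1, -0.2, 0.3])

        # Freezing happens here once; the nested find_MAP call receives freeze_model=False.
        idata = pmx.fit_laplace(chains=2, draws=500, freeze_model=True)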

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 2 additions & 5 deletions
@@ -22,7 +22,7 @@
 from collections.abc import Callable, Iterator
 from dataclasses import asdict, dataclass, field, replace
 from enum import Enum, auto
-from typing import Literal, TypeAlias
+from typing import Literal, Self, TypeAlias

 import arviz as az
 import filelock
@@ -60,9 +60,6 @@
 from rich.table import Table
 from rich.text import Text

-# TODO: change to typing.Self after Python versions greater than 3.10
-from typing_extensions import Self
-
 from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data
 from pymc_extras.inference.pathfinder.importance_sampling import (
     importance_sampling as _importance_sampling,
@@ -533,7 +530,7 @@ def bfgs_sample_sparse(

     # qr_input: (L, N, 2J)
     qr_input = inv_sqrt_alpha_diag @ beta
-    (Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False)
+    (Q, R), _ = pytensor.scan(fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False)

     IdN = pt.eye(R.shape[1])[None, ...]
     IdN += IdN * REGULARISATION_TERM
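
Two small cleanups here: Self now comes from the standard library typing module, which exists only on Python 3.11+, so the typing_extensions fallback and its TODO could be dropped; and the QR factorization inside the scan uses the pt.linalg.qr spelling instead of pt.nlinalg.qr. For reference, a tiny sketch of the pattern the Self import supports (class and method names are illustrative, not from this file):

    from dataclasses import dataclass, replace
    from typing import Self  # requires Python >= 3.11


    @dataclass(frozen=True)
    class Config:
        maxiter: int = 1000

        def with_maxiter(self, maxiter: int) -> Self:
            # Self resolves to the runtime class, so subclasses keep their own type.
            return replace(self, maxiter=maxiter)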

pymc_extras/model/marginal/graph_analysis.py

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,8 @@
 from pymc import SymbolicRandomVariable
 from pymc.model.fgraph import ModelVar
 from pymc.variational.minibatch_rv import MinibatchRandomVariable
-from pytensor.graph import Variable, ancestors
-from pytensor.graph.basic import io_toposort
+from pytensor.graph.basic import Variable
+from pytensor.graph.traversal import ancestors, io_toposort
 from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
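
ancestors and io_toposort now live in pytensor.graph.traversal in the pinned pytensor release. If compatibility with older pytensor versions were ever needed, a guarded import would be one option; this is only an illustrative sketch, not something the commit does:

    try:
        # pytensor >= 2.35 exposes the traversal helpers from their own module
        from pytensor.graph.traversal import ancestors, io_toposort
    except ImportError:
        # older releases exposed them from pytensor.graph / pytensor.graph.basic
        from pytensor.graph import ancestors
        from pytensor.graph.basic import io_toposort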

pymc_extras/statespace/core/compile.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def compile_statespace(
         x0, P0, c, d, T, Z, R, H, Q, steps=steps, sequence_names=sequence_names
     )

-    inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
+    inputs = list(pytensor.graph.traversal.explicit_graph_inputs(outputs))

     _f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)


pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 2 additions & 2 deletions
@@ -200,7 +200,7 @@ def build_graph(
         self.n_endog = Z_shape[-2]

         data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q)
-
+        data = pt.specify_shape(data, (data.type.shape[0], self.n_endog))
         sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq(
             params, PARAM_NAMES
         )
@@ -658,7 +658,7 @@ def update(self, a, P, y, d, Z, H, all_nan_flag):
         # Construct upper-triangular block matrix A = [[chol(H), Z @ L_pred],
         #                                              [0,       L_pred]]
         # The Schur decomposition of this matrix will be B (upper triangular). We are
-        # more insterested in B^T:
+        # more interested in B^T:
         # Structure of B^T = [[chol(F),     0               ],
         #                     [K @ chol(F), chol(P_filtered)]
         zeros = pt.zeros((self.n_states, self.n_endog))
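
The added specify_shape call pins the trailing dimension of the data to n_endog, so the static shape is known to the scan and to downstream rewrites. A self-contained sketch of the effect, with a hypothetical n_endog of 3 standing in for self.n_endog:

    import pytensor.tensor as pt

    data = pt.matrix("data")  # static shape (None, None)
    n_endog = 3               # stands in for self.n_endog
    data = pt.specify_shape(data, (data.type.shape[0], n_endog))
    print(data.type.shape)    # (None, 3): the second dimension is now fixed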

pymc_extras/statespace/filters/kalman_smoother.py

Lines changed: 1 addition & 3 deletions
@@ -1,8 +1,6 @@
 import pytensor
 import pytensor.tensor as pt

-from pytensor.tensor.nlinalg import matrix_dot
-
 from pymc_extras.statespace.filters.utilities import (
     quad_form_sym,
     split_vars_into_seq_and_nonseq,
@@ -105,7 +103,7 @@ def smoother_step(self, *args):
         a_hat, P_hat = self.predict(a, P, T, R, Q)

         # Use pinv, otherwise P_hat is singular when there is missing data
-        smoother_gain = matrix_dot(pt.linalg.pinv(P_hat, hermitian=True), T, P).T
+        smoother_gain = (pt.linalg.pinv(P_hat, hermitian=True) @ T @ P).mT
         a_smooth_next = a + smoother_gain @ (a_smooth - a_hat)

         P_smooth_next = P + quad_form_sym(smoother_gain, P_smooth - P_hat)
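
The dropped matrix_dot helper is just chained matrix multiplication, and .mT transposes only the last two axes, so the new expression matches the old one for 2-D operands and stays correct if the terms ever carry batch dimensions. An illustrative sketch of the equivalence:

    import pytensor.tensor as pt

    X = pt.matrix("X")
    Y = pt.matrix("Y")
    Z = pt.matrix("Z")

    # matrix_dot(X, Y, Z).T == (X @ Y @ Z).mT for 2-D inputs; .mT swaps only the
    # last two axes, unlike .T, which reverses all of them.
    gain = (X @ Y @ Z).mT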

pymc_extras/statespace/models/DFM.py

Lines changed: 12 additions & 27 deletions
@@ -5,7 +5,7 @@
 import pytensor.tensor as pt

 from pymc_extras.statespace.core.statespace import PyMCStateSpace
-from pymc_extras.statespace.models.utilities import make_default_coords
+from pymc_extras.statespace.models.utilities import make_default_coords, validate_names
 from pymc_extras.statespace.utils.constants import (
     ALL_STATE_AUX_DIM,
     ALL_STATE_DIM,
@@ -224,9 +224,7 @@ def __init__(
         self,
         k_factors: int,
         factor_order: int,
-        k_endog: int | None = None,
         endog_names: Sequence[str] | None = None,
-        k_exog: int | None = None,
         exog_names: Sequence[str] | None = None,
         shared_exog_states: bool = False,
         exog_innovations: bool = False,
@@ -249,19 +247,11 @@ def __init__(
            and are modeled as a white noise process, i.e., :math:`f_t = \varepsilon_{f,t}`.
            Therefore, the state vector will include one state per factor and "factor_ar" will not exist.

-        k_endog : int, optional
-            Number of observed time series. If not provided, the number of observed series will be inferred from `endog_names`.
-            At least one of `k_endog` or `endog_names` must be provided.
-
         endog_names : list of str, optional
-            Names of the observed time series. If not provided, default names will be generated as `endog_1`, `endog_2`, ..., `endog_k` based on `k_endog`.
-            At least one of `k_endog` or `endog_names` must be provided.
-
-        k_exog : int, optional
-            Number of exogenous variables. If not provided, the model will not have exogenous variables.
+            Names of the observed time series.

         exog_names : Sequence[str], optional
-            Names of the exogenous variables. If not provided, but `k_exog` is specified, default names will be generated as `exog_1`, `exog_2`, ..., `exog_k`.
+            Names of the exogenous variables.

         shared_exog_states: bool, optional
             Whether exogenous latent states are shared across the observed states. If True, there will be only one set of exogenous latent
@@ -289,13 +279,8 @@ def __init__(

         """

-        if k_endog is None and endog_names is None:
-            raise ValueError("Either k_endog or endog_names must be provided.")
-        if k_endog is None:
-            k_endog = len(endog_names)
-        if endog_names is None:
-            endog_names = [f"endog_{i}" for i in range(k_endog)]
-
+        validate_names(endog_names, var_name="endog_names", optional=False)
+        k_endog = len(endog_names)
         self.endog_names = endog_names
         self.k_endog = k_endog
         self.k_factors = k_factors
@@ -304,17 +289,17 @@ def __init__(
         self.error_var = error_var
         self.error_cov_type = error_cov_type

-        if k_exog is None and exog_names is None:
-            self.k_exog = 0
-        else:
+        if exog_names is not None:
             self.shared_exog_states = shared_exog_states
             self.exog_innovations = exog_innovations
-            if k_exog is None:
-                k_exog = len(exog_names) if exog_names is not None else 0
-            elif exog_names is None:
-                exog_names = [f"exog_{i}" for i in range(k_exog)] if k_exog > 0 else None
+            validate_names(
+                exog_names, var_name="exog_names", optional=True
+            )  # Not sure if this adds anything
+            k_exog = len(exog_names)
             self.k_exog = k_exog
             self.exog_names = exog_names
+        else:
+            self.k_exog = 0

         self.k_exog_states = self.k_exog * self.k_endog if not shared_exog_states else self.k_exog
         self.exog_flag = self.k_exog > 0
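
With k_endog and k_exog removed, the dynamic factor model is now sized entirely from the name lists: endog_names is required (checked by validate_names) and k_endog = len(endog_names), while leaving exog_names as None simply disables the exogenous block. A hedged construction sketch (the class name below is assumed from DFM.py; check the module for the exact public name and remaining keyword arguments):

    from pymc_extras.statespace.models.DFM import BayesianDynamicFactor  # name assumed

    dfm = BayesianDynamicFactor(
        k_factors=1,
        factor_order=1,
        endog_names=["gdp", "inflation", "unemployment"],  # k_endog inferred as len(endog_names)
        exog_names=None,                                   # no exogenous variables -> k_exog = 0
    )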
