@@ -67,9 +67,13 @@ def _unconstrained_vector_to_constrained_rvs(model):
     unconstrained_vector.name = "unconstrained_vector"
 
     # Redo the names list to ensure it is sorted to match the return order
-    names = [*constrained_names, *unconstrained_names]
+    constrained_rvs_and_names = [(rv, name) for rv, name in zip(constrained_rvs, constrained_names)]
+    value_rvs_and_names = [
+        (rv, name) for rv, name in zip(value_rvs, unconstrained_names)
+    ]
+    # Replaces the old flat list: names = [*constrained_names, *unconstrained_names]
 
-    return names, constrained_rvs, value_rvs, unconstrained_vector
+    return constrained_rvs_and_names, value_rvs_and_names, unconstrained_vector
 
 
 def model_to_laplace_approx(
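Review note: the refactor above replaces three parallel return values with explicit `(rv, name)` pairs, so a variable can never be matched to the wrong name if ordering changes. A minimal sketch of the idea, using plain strings as hypothetical stand-ins for the real PyMC variables:

```python
# Hypothetical stand-ins for the PyMC variables, just to show the pairing.
constrained_names = ["mu", "sigma_log__"]
constrained_rvs = ["<mu rv>", "<sigma rv>"]

# Pair each variable with its name instead of carrying two parallel lists.
constrained_rvs_and_names = list(zip(constrained_rvs, constrained_names))

for rv, name in constrained_rvs_and_names:
    print(name, "->", rv)
# mu -> <mu rv>
# sigma_log__ -> <sigma rv>
```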
@@ -81,8 +85,11 @@ def model_to_laplace_approx(
 
     # temp_chain and temp_draw are a hack to allow sampling from the Laplace approximation. We only have one mu and cov,
     # so we add batch dims (which correspond to chains and draws). But the names "chain" and "draw" are reserved.
-    names, constrained_rvs, value_rvs, unconstrained_vector = (
-        _unconstrained_vector_to_constrained_rvs(model)
+
+    # The model was frozen during the find_MAP procedure. To ensure we're operating on the same model, freeze it again.
+    frozen_model = freeze_dims_and_data(model)
+    constrained_rvs_and_names, _, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(
+        frozen_model
     )
 
     coords = model.coords | {
@@ -103,12 +110,13 @@ def model_to_laplace_approx(
     )
 
     cast_to_var = partial(type_cast, Variable)
+    constrained_rvs, constrained_names = zip(*constrained_rvs_and_names)
     batched_rvs = vectorize_graph(
         type_cast(list[Variable], constrained_rvs),
         replace={cast_to_var(unconstrained_vector): cast_to_var(laplace_approximation)},
     )
 
-    for name, batched_rv in zip(names, batched_rvs):
+    for name, batched_rv in zip(constrained_names, batched_rvs):
         batch_dims = ("temp_chain", "temp_draw")
         if batched_rv.ndim == 2:
             dims = batch_dims
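Review note: `zip(*constrained_rvs_and_names)` is the standard transpose idiom; it recovers the two aligned tuples at the point where `vectorize_graph` and the naming loop need them:

```python
# zip(*pairs) transposes a list of (rv, name) tuples into two aligned tuples.
pairs = [("rv_a", "a"), ("rv_b", "b")]
rvs, names = zip(*pairs)
assert rvs == ("rv_a", "rv_b")
assert names == ("a", "b")
```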
@@ -184,6 +192,7 @@ def fit_laplace(
     jitter_rvs: list[pt.TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
+    freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
     chains: int = 2,
     draws: int = 500,
@@ -227,6 +236,10 @@ def fit_laplace(
     include_transformed: bool, default True
         Whether to include transformed variables in the output. If True, transformed variables will be included in the
         output InferenceData object. If False, only the original variables will be included.
+    freeze_model: bool, optional
+        If True, `freeze_dims_and_data` is called on the model before compiling the loss functions. This is
+        sometimes required when using the JAX backend, and can improve performance by allowing constant folding.
+        Defaults to True.
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
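Review note: for readers unfamiliar with the flag, here is a sketch of what freezing does, assuming `freeze_dims_and_data` is importable from `pymc.model.transform.optimization` (true for recent PyMC releases):

```python
import numpy as np
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model(coords={"obs": range(3)}) as m:
    data = pm.Data("data", np.zeros(3), dims="obs")
    x = pm.Normal("x", mu=data, dims="obs")

# Mutable dims and data containers become constants, so PyTensor can
# constant-fold static shapes: required for JAX, often faster elsewhere.
frozen = freeze_dims_and_data(m)
```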
@@ -275,6 +288,9 @@
     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
     model = pm.modelcontext(model) if model is None else model
 
+    if freeze_model:
+        model = freeze_dims_and_data(model)
+
     idata = find_MAP(
         method=optimize_method,
         model=model,
@@ -286,6 +302,7 @@ def fit_laplace(
         jitter_rvs=jitter_rvs,
         progressbar=progressbar,
         include_transformed=include_transformed,
+        freeze_model=False,
         gradient_backend=gradient_backend,
         compile_kwargs=compile_kwargs,
         compute_hessian=True,
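Review note: end to end, the new flag means `fit_laplace` freezes the model once up front and then tells `find_MAP` not to freeze again. A hedged usage sketch (import path assumed from pymc-extras' public API):

```python
import pymc as pm
from pymc_extras.inference import fit_laplace

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

# freeze_model=True freezes dims/data here, once; internally find_MAP then
# receives freeze_model=False so the already-frozen model is reused as-is.
idata = fit_laplace(model=model, freeze_model=True, chains=2, draws=500)
```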