def prefixed_name(self, base_name: str) -> str:
    """
    Return ``base_name`` namespaced with this model's ``name``.

    Parameters
    ----------
    base_name : str
        Un-namespaced variable or data name.

    Returns
    -------
    str
        ``base_name`` unchanged when ``self.name`` is falsy (``None`` or an empty string) or when
        ``base_name`` already begins with ``"<name>_"``; otherwise ``"<name>_<base_name>"``.

    Notes
    -----
    Because an existing ``"<name>_"`` prefix is left alone, applying this method to its own output
    returns the same string.
    """
    namespace = self.name
    if namespace:
        marker = f"{namespace}_"
        # Only prepend the marker when it is not already there, so repeated calls do not stack prefixes.
        if not base_name.startswith(marker):
            return f"{marker}{base_name}"
    return base_name
| None = None, dtype=floatX ) -> pt.TensorVariable: """ Helper function to create a pytensor symbolic variable and register it in the _name_to_variable dictionary @@ -643,12 +657,14 @@ def make_and_register_variable( An error is raised if the provided name has already been registered, or if the name is not present in the ``param_names`` property. """ - if name not in self.param_names: + if base_name not in self.param_names: raise ValueError( - f"{name} is not a model parameter. All placeholder variables should correspond to model " + f"{base_name} is not a model parameter. All placeholder variables should correspond to model " f"parameters." ) + name = self.prefixed_name(base_name) + if name in self._tensor_variable_info: raise ValueError( f"{name} is already a registered placeholder variable with shape " @@ -661,7 +677,7 @@ def make_and_register_variable( return placeholder def make_and_register_data( - self, name: str, shape: int | tuple[int], dtype: str = floatX + self, base_name: str, shape: int | tuple[int], dtype: str = floatX ) -> Variable: r""" Helper function to create a pytensor symbolic variable and register it in the _name_to_data dictionary @@ -683,12 +699,14 @@ def make_and_register_data( An error is raised if the provided name has already been registered, or if the name is not present in the ``data_names`` property. """ - if name not in self.data_names: + if base_name not in self.data_names: raise ValueError( - f"{name} is not a model parameter. All placeholder variables should correspond to model " - f"parameters." + f"{base_name} is not a model data variable. All placeholder variables should correspond to model " + f"data variables." 
) + name = self.prefixed_name(base_name) + if name in self._tensor_data_info: raise ValueError( f"{name} is already a registered placeholder variable with shape " @@ -800,11 +818,12 @@ def _save_exogenous_data_info(self): """ pymc_mod = modelcontext(None) for data_name in self.data_names: - data = pymc_mod[data_name] + name = self.prefixed_name(data_name) + data = pymc_mod[name] self._fit_exog_data[data_name] = { - "name": data_name, + "name": name, "value": data.get_value(), - "dims": pymc_mod.named_vars_to_dims.get(data_name, None), + "dims": pymc_mod.named_vars_to_dims.get(name, None), } def _insert_random_variables(self): @@ -843,9 +862,10 @@ def _insert_random_variables(self): found_params = [] with pymc_model: for param_name in self.param_names: - param = getattr(pymc_model, param_name, None) + name = self.prefixed_name(param_name) + param = getattr(pymc_model, name, None) if param is not None: - found_params.append(param.name) + found_params.append(param_name) missing_params = list(set(self.param_names) - set(found_params)) if len(missing_params) > 0: @@ -880,9 +900,10 @@ def _insert_data_variables(self): found_data = [] with pymc_model: for data_name in data_names: - data = getattr(pymc_model, data_name, None) + name = self.prefixed_name(data_name) + data = getattr(pymc_model, name, None) if data is not None: - found_data.append(data.name) + found_data.append(data_name) missing_data = list(set(data_names) - set(found_data)) if len(missing_data) > 0: @@ -1046,6 +1067,7 @@ def build_statespace_graph( obs_coords=obs_coords, register_data=register_data, missing_fill_value=missing_fill_value, + data_name=self.prefixed_name(self.data_name), ) filter_outputs = self.kalman_filter.build_graph( @@ -1144,7 +1166,8 @@ def _build_dummy_graph(self) -> None: A list of pm.Flat variables representing all parameters estimated by the model. 
""" - def infer_variable_shape(name): + def infer_variable_shape(base_name): + name = self.prefixed_name(base_name) shape = self._name_to_variable[name].type.shape if not any(dim is None for dim in shape): return shape @@ -1152,7 +1175,7 @@ def infer_variable_shape(name): dim_names = self._fit_dims.get(name, None) if dim_names is None: raise ValueError( - f"Could not infer shape for {name}, because it was not given coords during model" + f"Could not infer shape for {base_name}, because it was not given coords during model" f"fitting" ) @@ -1164,11 +1187,11 @@ def infer_variable_shape(name): ] ) - for name in self.param_names: + for base_name in self.param_names: pm.Flat( - name, - shape=infer_variable_shape(name), - dims=self._fit_dims.get(name, None), + self.prefixed_name(base_name), + shape=infer_variable_shape(base_name), + dims=self._fit_dims.get(self.prefixed_name(base_name), None), ) def _kalman_filter_outputs_from_dummy_graph( @@ -1208,14 +1231,14 @@ def _kalman_filter_outputs_from_dummy_graph( self._insert_random_variables() for name in self.data_names: - if name not in pm_mod: + if self.prefixed_name(name) not in pm_mod: pm.Data(**self._fit_exog_data[name]) self._insert_data_variables() for name in self.data_names: if name in scenario.keys(): - pm.set_data({name: scenario[name]}) + pm.set_data({self.prefixed_name(name): scenario[name]}) x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace() @@ -1230,6 +1253,7 @@ def _kalman_filter_outputs_from_dummy_graph( obs_coords=obs_coords, data_dims=data_dims, register_data=True, + data_name=self.prefixed_name(self.data_name), ) filter_outputs = self.kalman_filter.build_graph( @@ -1786,7 +1810,7 @@ def sample_statespace_matrices( self._insert_random_variables() for name in self.data_names: - pm.Data(**self.data_info[name]) + pm.Data(name=self.prefixed_name(name), **self.data_info[name]) self._insert_data_variables() matrices = self.unpack_statespace() @@ -1852,6 +1876,7 @@ def sample_filter_outputs( 
n_obs=self.ssm.k_endog, obs_coords=obs_coords, register_data=True, + data_name=self.prefixed_name(self.data_name), ) filter_outputs = self.kalman_filter.build_graph( @@ -2283,12 +2308,18 @@ def _build_forecast_model( mu, cov = grouped_outputs[group_idx] sub_dict = { - data_var: pt.as_tensor_variable(data_var.get_value(), name="data") + data_var: pt.as_tensor_variable( + data_var.get_value(), name=self.prefixed_name(self.data_name) + ) for data_var in forecast_model.data_vars } missing_data_vars = np.setdiff1d( - ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()] + ar1=[ + *[self.prefixed_name(name) for name in self.data_names], + self.prefixed_name(self.data_name), + ], + ar2=[k.name for k, _ in sub_dict.items()], ) if missing_data_vars.size > 0: raise ValueError(f"{missing_data_vars} data used for fitting not found!") @@ -2466,8 +2497,11 @@ def forecast( with forecast_model: if scenario is not None: dummy_obs_data = np.zeros((len(forecast_index), self.k_endog)) + scoped_scenario = { + self.prefixed_name(name): value for name, value in scenario.items() + } pm.set_data( - scenario | {"data": dummy_obs_data}, + scoped_scenario | {self.prefixed_name(self.data_name): dummy_obs_data}, coords={"data_time": np.arange(len(forecast_index))}, ) diff --git a/pymc_extras/statespace/utils/data_tools.py b/pymc_extras/statespace/utils/data_tools.py index cbc5d517c..204021529 100644 --- a/pymc_extras/statespace/utils/data_tools.py +++ b/pymc_extras/statespace/utils/data_tools.py @@ -121,7 +121,7 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals return preprocess_numpy_data(data.values, n_obs, obs_coords) -def add_data_to_active_model(values, index, data_dims=None): +def add_data_to_active_model(values, index, data_dims=None, data_name="data"): pymc_mod = modelcontext(None) if data_dims is None: data_dims = [TIME_DIM, OBS_STATE_DIM] @@ -146,7 +146,7 @@ def add_data_to_active_model(values, index, data_dims=None): else: 
def test_two_statespace_models_can_coexist_with_names(monkeypatch):
    """
    Check that two named state space models produce distinct, predictable namespaced names.

    Only the naming contract of ``prefixed_name`` is exercised here; no statespace graph is built.
    """
    # Replace ``make_symbolic_graph`` with a no-op so the bare base class can be constructed
    # (presumably it is not implemented on the base class -- confirm against the class definition).
    monkeypatch.setattr(PyMCStateSpace, "make_symbolic_graph", lambda self: None)

    with pm.Model():
        ssm_a = PyMCStateSpace(k_endog=1, k_states=1, k_posdef=1, name="a")
        ssm_b = PyMCStateSpace(k_endog=1, k_states=1, k_posdef=1, name="b")

    name_a = ssm_a.prefixed_name("data")
    name_b = ssm_b.prefixed_name("data")

    # The two models must not collide on the shared base name.
    assert name_a != name_b

    # Pin the exact format, so a change in the prefixing scheme is caught rather than passing by inequality alone.
    assert name_a == "a_data"
    assert name_b == "b_data"

    # Re-applying the prefix to an already-prefixed name must not stack prefixes.
    assert ssm_a.prefixed_name(name_a) == name_a
    assert ssm_b.prefixed_name(name_b) == name_b

    # Another model's prefix is not mistaken for this model's own.
    assert ssm_b.prefixed_name(name_a) == "b_a_data"