Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 60 additions & 26 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ class PyMCStateSpace:
Regardless of whether a mode is specified, it can always be overwritten via the ``compile_kwargs`` argument
to all sampling methods.

name : str, optional
Prefix used to namespace internal graph variable and data names so multiple state space models can coexist
in the same PyMC model without naming collisions. If ``None``, the default naming behavior is used.

Notes
-----
Based on the statsmodels statespace implementation https://github.com/statsmodels/statsmodels/blob/main/statsmodels/tsa/statespace/representation.py,
Expand Down Expand Up @@ -261,6 +265,8 @@ def __init__(
verbose: bool = True,
measurement_error: bool = False,
mode: str | None = None,
name: str | None = None,
data_name: str = "data",
):
self._fit_coords: dict[str, Sequence[str]] | None = None
self._fit_dims: dict[str, Sequence[str]] | None = None
Expand All @@ -274,6 +280,8 @@ def __init__(
self.k_endog = k_endog
self.k_states = k_states
self.k_posdef = k_posdef
self.name = name
self.data_name = data_name
self.measurement_error = measurement_error
self.mode = mode

Expand Down Expand Up @@ -305,6 +313,12 @@ def __init__(
console = Console()
console.print(self.requirement_table)

def prefixed_name(self, base_name: str) -> str:
    """
    Return ``base_name`` namespaced with this model's ``name``.

    Parameters
    ----------
    base_name : str
        Variable or data name to be namespaced.

    Returns
    -------
    str
        ``base_name`` unchanged when ``self.name`` is falsy (``None`` or an empty string) or when
        ``base_name`` already starts with ``"<name>_"``; otherwise ``"<name>_<base_name>"``.
    """
    # Names that already start with the prefix are returned as-is, so applying this twice
    # never yields a doubled prefix.
    if self.name and not base_name.startswith(f"{self.name}_"):
        return f"{self.name}_{base_name}"
    return base_name

def _populate_properties(self) -> None:
self._set_parameters()
self._set_states()
Expand Down Expand Up @@ -614,7 +628,7 @@ def add_default_priors(self) -> None:
raise NotImplementedError("The add_default_priors property has not been implemented!")

def make_and_register_variable(
self, name, shape: int | tuple[int, ...] | None = None, dtype=floatX
self, base_name, shape: int | tuple[int, ...] | None = None, dtype=floatX
) -> pt.TensorVariable:
"""
Helper function to create a pytensor symbolic variable and register it in the _name_to_variable dictionary
Expand Down Expand Up @@ -643,12 +657,14 @@ def make_and_register_variable(
An error is raised if the provided name has already been registered, or if the name is not present in the
``param_names`` property.
"""
if name not in self.param_names:
if base_name not in self.param_names:
raise ValueError(
f"{name} is not a model parameter. All placeholder variables should correspond to model "
f"{base_name} is not a model parameter. All placeholder variables should correspond to model "
f"parameters."
)

name = self.prefixed_name(base_name)

if name in self._tensor_variable_info:
raise ValueError(
f"{name} is already a registered placeholder variable with shape "
Expand All @@ -661,7 +677,7 @@ def make_and_register_variable(
return placeholder

def make_and_register_data(
self, name: str, shape: int | tuple[int], dtype: str = floatX
self, base_name: str, shape: int | tuple[int], dtype: str = floatX
) -> Variable:
r"""
Helper function to create a pytensor symbolic variable and register it in the _name_to_data dictionary
Expand All @@ -683,12 +699,14 @@ def make_and_register_data(
An error is raised if the provided name has already been registered, or if the name is not present in the
``data_names`` property.
"""
if name not in self.data_names:
if base_name not in self.data_names:
raise ValueError(
f"{name} is not a model parameter. All placeholder variables should correspond to model "
f"parameters."
f"{base_name} is not a model data variable. All placeholder variables should correspond to model "
f"data variables."
)

name = self.prefixed_name(base_name)

if name in self._tensor_data_info:
raise ValueError(
f"{name} is already a registered placeholder variable with shape "
Expand Down Expand Up @@ -800,11 +818,12 @@ def _save_exogenous_data_info(self):
"""
pymc_mod = modelcontext(None)
for data_name in self.data_names:
data = pymc_mod[data_name]
name = self.prefixed_name(data_name)
data = pymc_mod[name]
self._fit_exog_data[data_name] = {
"name": data_name,
"name": name,
"value": data.get_value(),
"dims": pymc_mod.named_vars_to_dims.get(data_name, None),
"dims": pymc_mod.named_vars_to_dims.get(name, None),
}

def _insert_random_variables(self):
Expand Down Expand Up @@ -843,9 +862,10 @@ def _insert_random_variables(self):
found_params = []
with pymc_model:
for param_name in self.param_names:
param = getattr(pymc_model, param_name, None)
name = self.prefixed_name(param_name)
param = getattr(pymc_model, name, None)
if param is not None:
found_params.append(param.name)
found_params.append(param_name)

missing_params = list(set(self.param_names) - set(found_params))
if len(missing_params) > 0:
Expand Down Expand Up @@ -880,9 +900,10 @@ def _insert_data_variables(self):
found_data = []
with pymc_model:
for data_name in data_names:
data = getattr(pymc_model, data_name, None)
name = self.prefixed_name(data_name)
data = getattr(pymc_model, name, None)
if data is not None:
found_data.append(data.name)
found_data.append(data_name)

missing_data = list(set(data_names) - set(found_data))
if len(missing_data) > 0:
Expand Down Expand Up @@ -1046,6 +1067,7 @@ def build_statespace_graph(
obs_coords=obs_coords,
register_data=register_data,
missing_fill_value=missing_fill_value,
data_name=self.prefixed_name(self.data_name),
)

filter_outputs = self.kalman_filter.build_graph(
Expand Down Expand Up @@ -1144,15 +1166,16 @@ def _build_dummy_graph(self) -> None:
A list of pm.Flat variables representing all parameters estimated by the model.
"""

def infer_variable_shape(name):
def infer_variable_shape(base_name):
name = self.prefixed_name(base_name)
shape = self._name_to_variable[name].type.shape
if not any(dim is None for dim in shape):
return shape

dim_names = self._fit_dims.get(name, None)
if dim_names is None:
raise ValueError(
f"Could not infer shape for {name}, because it was not given coords during model"
f"Could not infer shape for {base_name}, because it was not given coords during model"
f"fitting"
)

Expand All @@ -1164,11 +1187,11 @@ def infer_variable_shape(name):
]
)

for name in self.param_names:
for base_name in self.param_names:
pm.Flat(
name,
shape=infer_variable_shape(name),
dims=self._fit_dims.get(name, None),
self.prefixed_name(base_name),
shape=infer_variable_shape(base_name),
dims=self._fit_dims.get(self.prefixed_name(base_name), None),
)

def _kalman_filter_outputs_from_dummy_graph(
Expand Down Expand Up @@ -1208,14 +1231,14 @@ def _kalman_filter_outputs_from_dummy_graph(
self._insert_random_variables()

for name in self.data_names:
if name not in pm_mod:
if self.prefixed_name(name) not in pm_mod:
pm.Data(**self._fit_exog_data[name])

self._insert_data_variables()

for name in self.data_names:
if name in scenario.keys():
pm.set_data({name: scenario[name]})
pm.set_data({self.prefixed_name(name): scenario[name]})

x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()

Expand All @@ -1230,6 +1253,7 @@ def _kalman_filter_outputs_from_dummy_graph(
obs_coords=obs_coords,
data_dims=data_dims,
register_data=True,
data_name=self.prefixed_name(self.data_name),
)

filter_outputs = self.kalman_filter.build_graph(
Expand Down Expand Up @@ -1786,7 +1810,7 @@ def sample_statespace_matrices(
self._insert_random_variables()

for name in self.data_names:
pm.Data(**self.data_info[name])
pm.Data(name=self.prefixed_name(name), **self.data_info[name])

self._insert_data_variables()
matrices = self.unpack_statespace()
Expand Down Expand Up @@ -1852,6 +1876,7 @@ def sample_filter_outputs(
n_obs=self.ssm.k_endog,
obs_coords=obs_coords,
register_data=True,
data_name=self.prefixed_name(self.data_name),
)

filter_outputs = self.kalman_filter.build_graph(
Expand Down Expand Up @@ -2283,12 +2308,18 @@ def _build_forecast_model(
mu, cov = grouped_outputs[group_idx]

sub_dict = {
data_var: pt.as_tensor_variable(data_var.get_value(), name="data")
data_var: pt.as_tensor_variable(
data_var.get_value(), name=self.prefixed_name(self.data_name)
)
for data_var in forecast_model.data_vars
}

missing_data_vars = np.setdiff1d(
ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()]
ar1=[
*[self.prefixed_name(name) for name in self.data_names],
self.prefixed_name(self.data_name),
],
ar2=[k.name for k, _ in sub_dict.items()],
)
if missing_data_vars.size > 0:
raise ValueError(f"{missing_data_vars} data used for fitting not found!")
Expand Down Expand Up @@ -2466,8 +2497,11 @@ def forecast(
with forecast_model:
if scenario is not None:
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
scoped_scenario = {
self.prefixed_name(name): value for name, value in scenario.items()
}
pm.set_data(
scenario | {"data": dummy_obs_data},
scoped_scenario | {self.prefixed_name(self.data_name): dummy_obs_data},
coords={"data_time": np.arange(len(forecast_index))},
)

Expand Down
16 changes: 11 additions & 5 deletions pymc_extras/statespace/utils/data_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals
return preprocess_numpy_data(data.values, n_obs, obs_coords)


def add_data_to_active_model(values, index, data_dims=None):
def add_data_to_active_model(values, index, data_dims=None, data_name="data"):
pymc_mod = modelcontext(None)
if data_dims is None:
data_dims = [TIME_DIM, OBS_STATE_DIM]
Expand All @@ -146,7 +146,7 @@ def add_data_to_active_model(values, index, data_dims=None):
else:
data_shape = (None, values.shape[-1])

data = pm.Data("data", values, dims=data_dims, shape=data_shape)
data = pm.Data(data_name, values, dims=data_dims, shape=data_shape)

return data

Expand Down Expand Up @@ -178,7 +178,13 @@ def mask_missing_values_in_data(values, missing_fill_value=None):


def register_data_with_pymc(
data, n_obs, obs_coords, register_data=True, missing_fill_value=None, data_dims=None
data,
n_obs,
obs_coords,
register_data=True,
missing_fill_value=None,
data_dims=None,
data_name="data",
):
if isinstance(data, pt.TensorVariable | TensorSharedVariable):
values, index = preprocess_tensor_data(data, n_obs, obs_coords)
Expand All @@ -192,7 +198,7 @@ def register_data_with_pymc(
data, nan_mask = mask_missing_values_in_data(values, missing_fill_value)

if register_data:
data = add_data_to_active_model(data, index, data_dims)
data = add_data_to_active_model(data, index, data_dims, data_name=data_name)
else:
data = pytensor.shared(data, name="data")
data = pytensor.shared(data, name=data_name)
return data, nan_mask
10 changes: 10 additions & 0 deletions tests/statespace/core/test_statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,16 @@ def test_base_class_raises():
)


def test_two_statespace_models_can_coexist_with_names(monkeypatch):
    """Two differently named state space models must map the same base name to distinct names."""
    # Replace make_symbolic_graph with a no-op for the duration of this test.
    monkeypatch.setattr(PyMCStateSpace, "make_symbolic_graph", lambda self: None)

    shared_kwargs = {"k_endog": 1, "k_states": 1, "k_posdef": 1}
    with pm.Model():
        ssm_a = PyMCStateSpace(name="a", **shared_kwargs)
        ssm_b = PyMCStateSpace(name="b", **shared_kwargs)

    name_a = ssm_a.prefixed_name("data")
    name_b = ssm_b.prefixed_name("data")
    assert name_a != name_b


def test_update_raises_if_missing_variables(ss_mod):
with pm.Model() as mod:
rho = pm.Normal("rho")
Expand Down