Skip to content

Commit 7f32a48

Browse files
author
Jesse Grabowski
committed
Iterate on proposal
1 parent a4dacd8 commit 7f32a48

File tree

2 files changed

+295
-0
lines changed

2 files changed

+295
-0
lines changed
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
from collections.abc import Iterator
from dataclasses import dataclass, field, fields
from typing import Generic, Self, TypeVar

from pymc_extras.statespace.core import PyMCStateSpace
from pymc_extras.statespace.models.structural.core import Component
from pymc_extras.statespace.utils.constants import (
    ALL_STATE_AUX_DIM,
    ALL_STATE_DIM,
    OBS_STATE_AUX_DIM,
    OBS_STATE_DIM,
    SHOCK_AUX_DIM,
    SHOCK_DIM,
)
15+
16+
17+
@dataclass(frozen=True)
18+
class Property:
19+
def __str__(self) -> str:
20+
return "\n".join(f"{f.name}: {getattr(self, f.name)}" for f in fields(self))
21+
22+
23+
T = TypeVar("T", bound=Property)
24+
25+
26+
@dataclass(frozen=True)
27+
class Info(Generic[T]):
28+
items: tuple[T, ...]
29+
key_field: str = "name"
30+
_index: dict[str, T] | None = None
31+
32+
def __post_init__(self):
33+
index = {}
34+
missing_attr = []
35+
for item in self.items:
36+
if not hasattr(item, self.key_field):
37+
missing_attr.append(item)
38+
continue
39+
key = getattr(item, self.key_field)
40+
if key in index:
41+
raise ValueError(f"Duplicate {self.key_field} '{key}' detected.")
42+
index[key] = item
43+
if missing_attr:
44+
raise AttributeError(f"Items missing attribute '{self.key_field}': {missing_attr}")
45+
object.__setattr__(self, "_index", index)
46+
47+
def _key(self, item: T) -> str:
48+
return getattr(item, self.key_field)
49+
50+
def get(self, key: str, default=None) -> T | None:
51+
return self._index.get(key, default)
52+
53+
def __getitem__(self, key: str) -> T:
54+
try:
55+
return self._index[key]
56+
except KeyError as e:
57+
available = ", ".join(self._index.keys())
58+
raise KeyError(f"No {self.key_field} '{key}'. Available: [{available}]") from e
59+
60+
def __contains__(self, key: object) -> bool:
61+
return key in self._index
62+
63+
def __iter__(self) -> Iterator[str]:
64+
return iter(self._index)
65+
66+
def __len__(self) -> int:
67+
return len(self._index)
68+
69+
def __str__(self) -> str:
70+
return f"{self.key_field}s: {list(self._index.keys())}"
71+
72+
@property
73+
def names(self) -> tuple[str, ...]:
74+
return tuple(self._index.keys())
75+
76+
77+
@dataclass(frozen=True)
class Parameter(Property):
    """Immutable record describing one parameter.

    ``name`` is the only required identifier; ``constraints`` is optional
    and defaults to ``None``.
    """

    name: str
    shape: tuple[int, ...]
    dims: tuple[str, ...]
    constraints: str | None = None
83+
84+
85+
@dataclass(frozen=True)
class ParameterInfo(Info[Parameter]):
    """Collection of ``Parameter`` records looked up by their ``name``."""

    def __init__(self, parameters: list[Parameter]):
        # Freeze the incoming sequence before handing it to the base class.
        frozen_items = tuple(parameters)
        super().__init__(items=frozen_items, key_field="name")
89+
90+
91+
@dataclass(frozen=True)
class Data(Property):
    """Immutable record describing one data variable.

    ``is_exogenous`` is the flag that ``DataInfo.needs_exogenous_data`` reads.
    """

    name: str
    shape: tuple[int, ...]
    dims: tuple[str, ...]
    is_exogenous: bool
97+
98+
99+
@dataclass(frozen=True)
class DataInfo(Info[Data]):
    """Collection of ``Data`` records looked up by their ``name``."""

    def __init__(self, data: list[Data]):
        frozen_items = tuple(data)
        super().__init__(items=frozen_items, key_field="name")

    @property
    def needs_exogenous_data(self) -> bool:
        """``True`` when at least one stored record is flagged ``is_exogenous``."""
        for record in self.items:
            if record.is_exogenous:
                return True
        return False

    def __str__(self) -> str:
        record_names = [record.name for record in self.items]
        return f"data: {record_names}\nneeds exogenous data: {self.needs_exogenous_data}"
110+
111+
112+
@dataclass(frozen=True)
class Coord(Property):
    """Immutable record pairing a dimension name with its labels."""

    dimension: str
    labels: tuple[str, ...]
116+
117+
118+
@dataclass(frozen=True)
class CoordInfo(Info[Coord]):
    """Collection of ``Coord`` records looked up by their ``dimension``."""

    def __init__(self, coords: list[Coord]):
        frozen_items = tuple(coords)
        super().__init__(items=frozen_items, key_field="dimension")

    def __str__(self) -> str:
        # Header followed by one indented block per coordinate; each block is
        # wrapped in newlines.
        pieces = ["coordinates:"]
        for coord in self.items:
            block = "\n".join(" " + line for line in str(coord).splitlines())
            pieces.append("\n" + block + "\n")
        return "".join(pieces)

    @classmethod
    def default_coords_from_model(cls, model: Component | PyMCStateSpace) -> Self:
        """Build the six default coordinates from a model's name lists.

        Reads ``state_names``, ``observed_state_names`` and ``shock_names``
        from ``model`` and assigns each set of labels to a dimension and to
        its auxiliary counterpart.
        """
        state_labels = tuple(model.state_names)
        observed_labels = tuple(model.observed_state_names)
        shock_labels = tuple(model.shock_names)

        groups = (
            ((ALL_STATE_DIM, ALL_STATE_AUX_DIM), state_labels),
            ((OBS_STATE_DIM, OBS_STATE_AUX_DIM), observed_labels),
            ((SHOCK_DIM, SHOCK_AUX_DIM), shock_labels),
        )

        coords = []
        for dim_pair, labels in groups:
            for dim in dim_pair:
                coords.append(Coord(dimension=dim, labels=labels))
        return cls(coords)
148+
149+
150+
@dataclass(frozen=True)
class State(Property):
    """Immutable record describing one state, with two boolean flags."""

    name: str
    observed: bool
    shared: bool
155+
156+
157+
@dataclass(frozen=True)
class StateInfo(Info[State]):
    """Collection of ``State`` records looked up by their ``name``."""

    def __init__(self, states: list[State]):
        frozen_items = tuple(states)
        super().__init__(items=frozen_items, key_field="name")

    def __str__(self) -> str:
        state_names = [state.name for state in self.items]
        observed_flags = [state.observed for state in self.items]
        return f"states: {state_names}\nobserved: {observed_flags}"
166+
167+
168+
@dataclass(frozen=True)
class Shock(Property):
    """Immutable record identifying one shock by name."""

    name: str
171+
172+
173+
@dataclass(frozen=True)
class ShockInfo(Info[Shock]):
    """Collection of ``Shock`` records looked up by their ``name``."""

    def __init__(self, shocks: list[Shock]):
        frozen_items = tuple(shocks)
        super().__init__(items=frozen_items, key_field="name")
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import pytest
2+
3+
from pymc_extras.statespace.core.properties import (
4+
CoordInfo,
5+
Data,
6+
DataInfo,
7+
Parameter,
8+
ParameterInfo,
9+
Shock,
10+
ShockInfo,
11+
State,
12+
StateInfo,
13+
)
14+
from pymc_extras.statespace.utils.constants import (
15+
ALL_STATE_AUX_DIM,
16+
ALL_STATE_DIM,
17+
OBS_STATE_AUX_DIM,
18+
OBS_STATE_DIM,
19+
SHOCK_AUX_DIM,
20+
SHOCK_DIM,
21+
)
22+
23+
24+
def test_property_str_formats_fields():
    """``str()`` of a record yields one ``field: value`` line per field."""
    param = Parameter(name="alpha", shape=(2,), dims=("param",))
    expected_lines = [
        "name: alpha",
        "shape: (2,)",
        "dims: ('param',)",
        "constraints: None",
    ]
    assert str(param).splitlines() == expected_lines
33+
34+
35+
def test_info_lookup_contains_and_missing_key():
    """``get``, indexing and ``in`` work; a missing key raises a descriptive KeyError."""
    info = ParameterInfo(
        [
            Parameter("a", (1,), ("d",)),
            Parameter("b", (2,), ("d",)),
            Parameter("c", (3,), ("d",)),
        ]
    )

    assert info.get("b").name == "b"
    assert info["a"].shape == (1,)
    assert "c" in info

    with pytest.raises(KeyError) as e:
        _ = info["missing"]
    # Checked after the ``with`` block: statements following the raising line
    # inside the block would never execute.
    assert "No name 'missing'" in str(e.value)
50+
51+
52+
def test_data_info_needs_exogenous_and_str():
    """``needs_exogenous_data`` reflects the flags and shows up in ``str()``."""
    mixed = DataInfo(
        [
            Data("price", (10,), ("time",), is_exogenous=False),
            Data("x", (10,), ("time",), is_exogenous=True),
        ]
    )

    assert mixed.needs_exogenous_data is True
    rendered = str(mixed)
    assert "data: ['price', 'x']" in rendered
    assert "needs exogenous data: True" in rendered

    endogenous_only = DataInfo([Data("y", (10,), ("time",), is_exogenous=False)])
    assert endogenous_only.needs_exogenous_data is False
66+
67+
68+
def test_coord_info_make_defaults_from_component_and_types():
    """Default coords cover all six dimensions and store labels as tuples."""

    class DummyComponent:
        state_names = ["x1", "x2"]
        observed_state_names = ["x2"]
        shock_names = ["eps1"]

    ci = CoordInfo.default_coords_from_model(DummyComponent())

    state_labels = ("x1", "x2")
    observed_labels = ("x2",)
    shock_labels = ("eps1",)
    expected = [
        (ALL_STATE_DIM, state_labels),
        (ALL_STATE_AUX_DIM, state_labels),
        (OBS_STATE_DIM, observed_labels),
        (OBS_STATE_AUX_DIM, observed_labels),
        (SHOCK_DIM, shock_labels),
        (SHOCK_AUX_DIM, shock_labels),
    ]

    assert len(ci.items) == 6
    for dim, labels in expected:
        assert dim in ci
        coord = ci[dim]
        assert coord.labels == labels
        assert isinstance(coord.labels, tuple)
90+
91+
92+
def test_state_info_and_shockinfo_basic():
    """Basic lookup and string rendering for ``StateInfo`` and ``ShockInfo``."""
    state_info = StateInfo(
        [
            State("x1", observed=True, shared=False),
            State("x2", observed=False, shared=True),
        ]
    )
    assert state_info["x1"].observed is True

    rendered = str(state_info)
    assert "states: ['x1', 'x2']" in rendered
    assert "observed: [True, False]" in rendered

    shock_info = ShockInfo([Shock("s1"), Shock("s2")])
    assert "s1" in shock_info
    assert shock_info["s2"].name == "s2"
109+
110+
111+
def test_info_is_iterable_and_unpackable():
    """``names`` preserves insertion order and ``items`` can be unpacked."""
    info = ParameterInfo([Parameter("p1", (1,), ("d",)), Parameter("p2", (2,), ("d",))])

    assert info.names == ("p1", "p2")

    first, second = info.items
    assert first.name == "p1"
    assert second.name == "p2"

0 commit comments

Comments
 (0)