Skip to content

Commit a337b96

Browse files
Merge branch 'pymc-devs:main' into implement-pmx.fit-option-for-INLA-+-marginalisation-routine
2 parents 176ca6b + 250e81a commit a337b96

File tree

33 files changed

+3706
-876
lines changed

33 files changed

+3706
-876
lines changed

CONTRIBUTING.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,24 @@
11
# Contributing guide
22

33
Page in construction, for now go to https://github.com/pymc-devs/pymc-extras#questions.
4+
5+
## Building the documentation
6+
7+
To build the documentation locally, you need to install the necessary
8+
dependencies and then use `make` to build the HTML files.
9+
10+
First, install the package with the optional documentation dependencies:
11+
12+
```bash
13+
pip install ".[docs]"
14+
```
15+
16+
Then, navigate to the `docs` directory and run `make html`:
17+
18+
```bash
19+
cd docs
20+
make html
21+
```
22+
23+
The generated HTML files will be in the `docs/_build/html` directory. You can
24+
open the `index.html` file in that directory to view the documentation.

docs/api_reference.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Prior
5656
create_dim_handler
5757
handle_dims
5858
Prior
59+
register_tensor_transform
5960
VariableFactory
6061
sample_prior
6162
Censored

docs/statespace/models.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ Statespace Models
66
.. autosummary::
77
:toctree: generated
88

9-
BayesianSARIMA
9+
BayesianETS
10+
BayesianSARIMAX
1011
BayesianVARMAX
1112

1213
*********************

notebooks/deterministic_advi_example.ipynb

Lines changed: 1609 additions & 425 deletions
Large diffs are not rendered by default.

pymc_extras/deserialize.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
1414
from pymc_extras.deserialize import deserialize
1515
16-
prior_class_data = {
17-
"dist": "Normal",
18-
"kwargs": {"mu": 0, "sigma": 1}
19-
}
16+
prior_class_data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
2017
prior = deserialize(prior_class_data)
2118
# Prior("Normal", mu=0, sigma=1)
2219
@@ -26,6 +23,7 @@
2623
2724
from pymc_extras.deserialize import register_deserialization
2825
26+
2927
class MyClass:
3028
def __init__(self, value: int):
3129
self.value = value
@@ -34,6 +32,7 @@ def to_dict(self) -> dict:
3432
# Example of what the to_dict method might look like.
3533
return {"value": self.value}
3634
35+
3736
register_deserialization(
3837
is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
3938
deserialize=lambda data: MyClass(value=data["value"]),
@@ -80,18 +79,23 @@ class Deserializer:
8079
8180
from typing import Any
8281
82+
8383
class MyClass:
8484
def __init__(self, value: int):
8585
self.value = value
8686
87+
8788
from pymc_extras.deserialize import Deserializer
8889
90+
8991
def is_type(data: Any) -> bool:
9092
return data.keys() == {"value"} and isinstance(data["value"], int)
9193
94+
9295
def deserialize(data: dict) -> MyClass:
9396
return MyClass(value=data["value"])
9497
98+
9599
deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)
96100
97101
"""
@@ -196,6 +200,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
196200
197201
from pymc_extras.deserialize import register_deserialization
198202
203+
199204
class MyClass:
200205
def __init__(self, value: int):
201206
self.value = value
@@ -204,6 +209,7 @@ def to_dict(self) -> dict:
204209
# Example of what the to_dict method might look like.
205210
return {"value": self.value}
206211
212+
207213
register_deserialization(
208214
is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
209215
deserialize=lambda data: MyClass(value=data["value"]),

pymc_extras/distributions/continuous.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ class Chi:
265265
from pymc_extras.distributions import Chi
266266
267267
with pm.Model():
268-
x = Chi('x', nu=1)
268+
x = Chi("x", nu=1)
269269
"""
270270

271271
@staticmethod

pymc_extras/distributions/histogram_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
130130
... m = pm.Normal("m", dims="tests")
131131
... s = pm.LogNormal("s", dims="tests")
132132
... pot = pmx.distributions.histogram_approximation(
133-
... "pot", pm.Normal.dist(m, s),
134-
... observed=measurements, n_quantiles=50
133+
... "pot", pm.Normal.dist(m, s), observed=measurements, n_quantiles=50
135134
... )
136135
137136
For special cases like Zero Inflation in Continuous variables there is a flag.
@@ -143,8 +142,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
143142
... m = pm.Normal("m", dims="tests")
144143
... s = pm.LogNormal("s", dims="tests")
145144
... pot = pmx.distributions.histogram_approximation(
146-
... "pot", pm.Normal.dist(m, s),
147-
... observed=measurements, n_quantiles=50, zero_inflation=True
145+
... "pot",
146+
... pm.Normal.dist(m, s),
147+
... observed=measurements,
148+
... n_quantiles=50,
149+
... zero_inflation=True,
148150
... )
149151
"""
150152
try:

pymc_extras/distributions/multivariate/r2d2m2cp.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def R2D2M2CP(
305305
import pymc_extras as pmx
306306
import pymc as pm
307307
import numpy as np
308+
308309
X = np.random.randn(10, 3)
309310
b = np.random.randn(3)
310311
y = X @ b + np.random.randn(10) * 0.04 + 5
@@ -339,7 +340,7 @@ def R2D2M2CP(
339340
# "c" - a must have in the relation
340341
variables_importance=[10, 1, 34],
341342
# NOTE: try both
342-
centered=True
343+
centered=True,
343344
)
344345
# intercept prior centering should be around prior predictive mean
345346
intercept = y.mean()
@@ -365,7 +366,7 @@ def R2D2M2CP(
365366
r2_std=0.2,
366367
# NOTE: if you know where a variable should go
367368
# if you do not know, leave as 0.5
368-
centered=False
369+
centered=False,
369370
)
370371
# intercept prior centering should be around prior predictive mean
371372
intercept = y.mean()
@@ -394,7 +395,7 @@ def R2D2M2CP(
394395
# if you do not know, leave as 0.5
395396
positive_probs=[0.8, 0.5, 0.1],
396397
# NOTE: try both
397-
centered=True
398+
centered=True,
398399
)
399400
intercept = y.mean()
400401
obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)

pymc_extras/distributions/timeseries.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,10 @@ class DiscreteMarkovChain(Distribution):
113113
114114
with pm.Model() as markov_chain:
115115
P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
116-
init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
117-
markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
116+
init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
117+
markov_chain = pmx.DiscreteMarkovChain(
118+
"markov_chain", P=P, init_dist=init_dist, shape=(100,)
119+
)
118120
119121
"""
120122

0 commit comments

Comments
 (0)