Skip to content

Commit 415b042

Browse files
committed
Refactor shape_from_dims usage to Model method
1 parent c068382 commit 415b042

File tree

2 files changed

+2
-31
lines changed

2 files changed

+2
-31
lines changed

pymc/distributions/distribution.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
convert_size,
4949
find_size,
5050
rv_size_is_none,
51-
shape_from_dims,
5251
)
5352
from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob
5453
from pymc.logprob.basic import logp
@@ -522,7 +521,7 @@ def __new__(
522521
# finally, observed, to determine the shape of the variable.
523522
if kwargs.get("size") is None and kwargs.get("shape") is None:
524523
if dims is not None:
525-
kwargs["shape"] = shape_from_dims(dims, model)
524+
kwargs["shape"] = model.shape_from_dims(dims)
526525
elif observed is not None:
527526
kwargs["shape"] = tuple(observed.shape)
528527

pymc/distributions/shape_utils.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from pymc.util import StrongDims, StrongShape
4141

4242
if TYPE_CHECKING:
43-
from pymc.model import Model
43+
pass
4444
Shape = int | TensorVariable | Sequence[int | Variable]
4545
Dims = str | Sequence[str | None]
4646
DimsWithEllipsis = str | EllipsisType | Sequence[str | None | EllipsisType]
@@ -164,34 +164,6 @@ def convert_size(size: Size) -> StrongSize | None:
164164
)
165165

166166

167-
def shape_from_dims(dims: StrongDims, model: Model) -> StrongShape:
168-
"""Determine shape from a `dims` tuple.
169-
170-
Parameters
171-
----------
172-
dims : array-like
173-
A vector of dimension names or None.
174-
model : pm.Model
175-
The current model on stack.
176-
177-
Returns
178-
-------
179-
shape : tuple
180-
Shape inferred from model dimension lengths.
181-
"""
182-
if model is None:
183-
raise ValueError("model must be provided explicitly to infer shape from dims")
184-
185-
# Dims must be known already
186-
unknowndim_dims = set(dims) - set(model.dim_lengths)
187-
if unknowndim_dims:
188-
raise KeyError(
189-
f"Dimensions {unknowndim_dims} are unknown to the model and cannot be used to specify a `shape`."
190-
)
191-
192-
return tuple(model.dim_lengths[dname] for dname in dims)
193-
194-
195167
def find_size(
196168
shape: StrongShape | None,
197169
size: StrongSize | None,

0 commit comments

Comments
 (0)