Skip to content

Commit 1a31bb3

Browse files
committed
Allow ellipsis in specify_shape helper
1 parent 499c3d4 commit 1a31bb3

File tree

2 files changed

+47
-6
lines changed

2 files changed

+47
-6
lines changed

pytensor/tensor/shape.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections.abc import Sequence
33
from numbers import Number
44
from textwrap import dedent
5+
from types import EllipsisType
56
from typing import TYPE_CHECKING, Union, cast
67
from typing import cast as typing_cast
78

@@ -27,7 +28,7 @@
2728
if TYPE_CHECKING:
2829
from pytensor.tensor import TensorLike
2930

30-
ShapeValueType = None | np.integer | int | Variable
31+
ShapeValueType = None | EllipsisType | np.integer | int | Variable
3132

3233

3334
def register_shape_c_code(type, code, version=()):
@@ -549,26 +550,37 @@ def specify_shape(
549550
550551
If a dimension's shape value is ``None``, the size of that dimension is not
551552
considered fixed/static at runtime.
553+
554+
A single ``Ellipsis`` can be used to imply multiple ``None`` specified dimensions
552555
"""
556+
x = as_tensor_variable(x) # type: ignore[arg-type]
553557

554558
if not isinstance(shape, tuple | list):
555559
shape = (shape,)
556560

557561
# If shape is a symbolic 1d vector of fixed length, we separate the items into a
558562
# tuple with one entry per shape dimension
559-
if len(shape) == 1 and shape[0] is not None:
560-
shape_vector = ptb.as_tensor_variable(shape[0])
563+
if len(shape) == 1 and shape[0] not in (None, Ellipsis):
564+
shape_vector = ptb.as_tensor_variable(shape[0]) # type: ignore[arg-type]
561565
if shape_vector.ndim == 1:
562566
try:
563567
shape = tuple(shape_vector)
564568
except ValueError:
565569
raise ValueError("Shape vector must have fixed dimensions")
566570

571+
if Ellipsis in shape:
572+
ellipsis_pos = shape.index(Ellipsis)
573+
implied_none = x.type.ndim - (len(shape) - 1)
574+
shape = (
575+
*shape[:ellipsis_pos],
576+
*((None,) * implied_none),
577+
*shape[ellipsis_pos + 1 :],
578+
)
579+
if Ellipsis in shape[ellipsis_pos + 1 :]:
580+
raise ValueError("Multiple Ellipsis in specify_shape")
581+
567582
# If the specified shape is already encoded in the input static shape, do nothing
568583
# This ignores PyTensor constants in shape
569-
x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore]
570-
# The above is a type error in Python 3.9 but not 3.12.
571-
# Thus we need to ignore unused-ignore on 3.12.
572584
new_shape_info = any(
573585
s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None
574586
)

tests/tensor/test_shape.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,33 @@ def test_fixed_partial_shapes(self):
480480
y = specify_shape(x, (None, 5))
481481
assert y.type.shape == (3, 5)
482482

483+
def test_ellipsis(self):
484+
x = tensor("x", shape=(None, None, None, None))
485+
486+
y = specify_shape(x, ...)
487+
assert y.type.shape == (None, None, None, None)
488+
489+
y = specify_shape(x, (...,))
490+
assert y.type.shape == (None, None, None, None)
491+
492+
y = specify_shape(x, (..., 5))
493+
assert y.type.shape == (None, None, None, 5)
494+
495+
y = specify_shape(x, (5, ...))
496+
assert y.type.shape == (5, None, None, None)
497+
498+
y = specify_shape(x, (5, ..., 3))
499+
assert y.type.shape == (5, None, None, 3)
500+
501+
y = specify_shape(x, (5, ..., 3, None))
502+
assert y.type.shape == (5, None, 3, None)
503+
504+
y = specify_shape(x, (5, 1, ..., 3, None))
505+
assert y.type.shape == (5, 1, 3, None)
506+
507+
with pytest.raises(ValueError, match="Multiple Ellipsis in specify_shape"):
508+
specify_shape(x, (..., None, ...))
509+
483510
def test_python_perform(self):
484511
"""Test the Python `Op.perform` implementation."""
485512
x = scalar()
@@ -583,6 +610,8 @@ def test_direct_return(self):
583610

584611
assert specify_shape(x, (1, 2, None)) is x
585612
assert specify_shape(x, (None, None, None)) is x
613+
assert specify_shape(x, (...,)) is x
614+
assert specify_shape(x, (..., None)) is x
586615

587616
assert specify_shape(x, (1, 2, 3)) is not x
588617
assert specify_shape(x, (None, None, 3)) is not x

0 commit comments

Comments
 (0)