Commit b73acbc

1 parent 5db2bca commit b73acbc

2 files changed: 39 additions & 65 deletions

pytensor/tensor/rewriting/reshape_ops.py

Lines changed: 9 additions & 2 deletions
@@ -1,4 +1,5 @@
 from pytensor.graph import node_rewriter
+from pytensor.graph.rewriting.basic import copy_stack_trace
 from pytensor.tensor.rewriting.basic import register_canonicalize
 from pytensor.tensor.shape_ops import JoinDims, SplitDims

@@ -19,7 +20,10 @@ def local_split_dims_to_reshape(fgraph, node):
         *x.shape[axis + 1 :],
     ]

-    return [x.reshape(output_shape)]
+    new_x = x.reshape(output_shape)
+    copy_stack_trace(x, new_x)
+
+    return [new_x]


 @register_canonicalize
@@ -39,4 +43,7 @@ def local_join_dims_to_reshape(fgraph, node):
         *x.shape[start_axis + n_axes :],
     ]

-    return [x.reshape(output_shape)]
+    new_x = x.reshape(output_shape)
+
+    copy_stack_trace(x, new_x)
+    return [new_x]
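With this change, the canonicalization rewrites that lower SplitDims/JoinDims to a plain reshape also copy the stack trace of the replaced variable onto the replacement, so later error messages still point at the user's code. A minimal sketch of what copy_stack_trace does, outside of any rewrite (illustrative only; the variable names are made up):

import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import copy_stack_trace

x = pt.vector("x")           # PyTensor records a creation traceback on x.tag
new_x = x.reshape((2, -1))   # the replacement variable built by the rewrite
copy_stack_trace(x, new_x)   # new_x now reports x's creation site in errors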

pytensor/tensor/shape_ops.py

Lines changed: 30 additions & 63 deletions
@@ -11,13 +11,11 @@
 from pytensor.graph.replace import _vectorize_node
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor.basic import (
-    atleast_1d,
     expand_dims,
-    get_scalar_constant_value,
+    infer_static_shape,
     join,
     split,
 )
-from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import prod
 from pytensor.tensor.shape import ShapeValueType
 from pytensor.tensor.type import tensor
@@ -31,20 +29,12 @@ class JoinDims(Op):
     )
     view_map = {0: [0]}

-    def __init__(self, input_ndims: int, start_axis: int, n_axes: int):
+    def __init__(self, start_axis: int, n_axes: int):
         if start_axis < 0:
             raise ValueError("JoinDims start_axis must be non-negative")

         self.start_axis = start_axis
         self.n_axes = n_axes
-        self.input_ndims = input_ndims
-
-        output_ndims = 1 if not start_axis else min(1, input_ndims - n_axes)
-
-        input_signature = ",".join(f"i{i}" for i in range(input_ndims))
-        output_signature = ",".join(f"o{i}" for i in range(output_ndims))
-
-        self.gufunc_signature = f"({input_signature})->({output_signature})"

     @property
     def axis_range(self):
@@ -59,11 +49,6 @@ def output_shapes(self, input_shapes, joined_shape):

     def make_node(self, x: Variable) -> Apply:  # type: ignore[override]
         static_shapes = x.type.shape
-        if x.type.ndim != self.input_ndims:
-            raise ValueError(
-                f"Input ndim {x.type.ndim} is not equal to expected ndim {self.input_ndims}"
-            )
-
         axis_range = self.axis_range

         joined_shape = (
@@ -88,13 +73,24 @@ def perform(self, node, inputs, outputs):
         (x,) = inputs
         (out,) = outputs

-        output_shape = [
+        output_shape = (
             *x.shape[: self.start_axis],
             -1,
             *x.shape[self.start_axis + self.n_axes :],
-        ]
+        )
+
+        out[0] = x.reshape(output_shape)
+

-        out[0] = x.reshape(tuple(output_shape))
+@_vectorize_node.register(JoinDims)
+def _vectorize_joindims(op, node, x):
+    [old_x] = node.inputs
+
+    batched_ndims = x.type.ndim - old_x.type.ndim
+    start_axis = op.start_axis
+    n_axes = op.n_axes
+
+    return JoinDims(start_axis + batched_ndims, n_axes).make_node(x)


 def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
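Since the gufunc_signature is no longer built in __init__, the commit registers an explicit _vectorize_node dispatch for JoinDims: vectorizing over an input with extra leading batch dimensions simply rebuilds the op with start_axis shifted by the number of those dimensions. A hedged sketch of the intended behaviour through vectorize_graph (the shapes are illustrative assumptions):

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor.shape_ops import join_dims

x = pt.tensor("x", shape=(3, 4))
y = join_dims(x, axis=(0, 1))          # JoinDims(start_axis=0, n_axes=2) -> shape (12,)

batch_x = pt.tensor("batch_x", shape=(5, 3, 4))
batch_y = vectorize_graph(y, replace={x: batch_x})
# expected: JoinDims(start_axis=1, n_axes=2) applied to batch_x -> shape (5, 12)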
@@ -129,16 +125,12 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV
     elif not isinstance(axis, list | tuple):
         raise TypeError("axis must be an int, a list/tuple of ints, or None")

-    if not axis:
-        # The user passed an empty list/tuple, so we return the input as is
-        return x
-
     axis = normalize_axis_tuple(axis, x.ndim)

-    if any(i < 0 for i in axis):
-        raise ValueError("join_dims axis must be non-negative")
+    if len(axis) <= 1:
+        return x

-    if len(axis) > 1 and np.diff(axis).max() > 1:
+    if np.diff(axis).max() > 1:
         raise ValueError(
             f"join_dims axis must be consecutive, got normalized axis: {axis}"
         )
@@ -148,7 +140,7 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV

     return type_cast(
         TensorVariable,
-        JoinDims(input_ndims=x.ndim, start_axis=start_axis, n_axes=n_axes)(x),
+        JoinDims(start_axis=start_axis, n_axes=n_axes)(x),
     )


@@ -162,17 +154,11 @@ def __init__(self, axis: int):
         self.axis = axis

     def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
-        if shape.type.dtype not in ("int8", "int16", "int32", "int64"):
+        if shape.type.numpy_dtype.kind not in "iu":
             raise TypeError("shape must be an integer tensor")

-        def _get_constant_shape(x):
-            try:
-                return get_scalar_constant_value(x).item()
-            except NotScalarConstantError:
-                return x
-
         axis = self.axis
-        constant_shape = [_get_constant_shape(s) for s in shape]  # type: ignore[attr-defined]
+        _, constant_shape = infer_static_shape(shape)

         output_shapes = [
             *x.type.shape[:axis],
@@ -181,7 +167,7 @@ def _get_constant_shape(x):
         ]

         output = tensor(
-            shape=tuple([x if isinstance(x, int) else None for x in output_shapes]),
+            shape=tuple(x if isinstance(x, int) else None for x in output_shapes),
             dtype=x.type.dtype,
         )
         return Apply(self, [x, shape], [output])
@@ -199,11 +185,7 @@ def perform(self, node, inputs, outputs):
         (x, shape) = inputs
         (out,) = outputs

-        output_shape = [
-            *x.shape[: self.axis],
-            *shape,
-            *x.shape[self.axis + 1 :],
-        ]
+        output_shape = (*x.shape[: self.axis], *shape, *x.shape[self.axis + 1 :])

         out[0] = x.reshape(output_shape)

@@ -219,7 +201,7 @@ def _vectorize_splitdims(op, node, x, shape):
         return vectorize_node_fallback(op, node, x, shape)

     axis = op.axis
-    return split_dims(x, shape, axis=axis + batched_ndims).owner
+    return SplitDims(axis=axis + batched_ndims).make_node(x, shape)


 def split_dims(
@@ -272,7 +254,7 @@ def split_dims(
         return type_cast(TensorVariable, x.squeeze(axis=axis))

     [axis] = normalize_axis_tuple(axis, x.ndim)  # type: ignore[misc]
-    shape = as_tensor_variable(shape)  # type: ignore[arg-type]
+    shape = as_tensor_variable(shape, dtype="int64")  # type: ignore[arg-type]

     split_op = SplitDims(axis=axis)
     return type_cast(TensorVariable, split_op(x, shape))
@@ -468,13 +450,6 @@ def pack(
     reshaped_tensors: list[TensorVariable] = []
     packed_shapes: list[ShapeValueType] = []

-    if all([n_before == 0, n_after == 0, min_axes == 0]):
-        # Special case -- we're raveling everything
-        packed_shapes = [t.shape for t in tensor_list]
-        reshaped_tensors = [atleast_1d(join_dims(t, None)) for t in tensor_list]
-
-        return join(0, *reshaped_tensors), packed_shapes
-
     for i, input_tensor in enumerate(tensor_list):
         n_dim = input_tensor.ndim

@@ -488,24 +463,16 @@ def pack(

         if n_dim == min_axes:
             # If an input has the minimum number of axes, pack implicitly inserts a new axis based on the pattern
-            # implied by the axes. If n_before == 0, the reshape would be (-1, ...), so we need to expand at axis 0.
-            # If n_after == 0, the reshape would be (..., -1), so we need to expand at axis -1. If both are equal,
-            # the reshape will occur in the center of the tensor.
-            if n_before == 0:
-                input_tensor = expand_dims(input_tensor, axis=0)
-            elif n_after == 0:
-                input_tensor = expand_dims(input_tensor, axis=-1)
-            elif n_before == n_after:
-                input_tensor = expand_dims(input_tensor, axis=n_before)
-
+            # implied by the axes.
+            input_tensor = expand_dims(input_tensor, axis=n_before)
             reshaped_tensors.append(input_tensor)
             continue

         # The reshape we want is (shape[:before], -1, shape[n_after_packed:]). join_dims does (shape[:min(axes)], -1,
         # shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the
         # rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the
         # corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing).
-        join_axes = {n_before, n_after_packed - 1}
+        join_axes = range(n_before, n_after_packed)
         joined = join_dims(input_tensor, tuple(join_axes))
         reshaped_tensors.append(joined)

@@ -560,7 +527,7 @@ def unpack(

     split_inputs = split(
         packed_input,
-        splits_size=[prod(shape).astype(int) for shape in packed_shapes],
+        splits_size=[prod(shape, dtype=int) for shape in packed_shapes],
         n_splits=len(packed_shapes),
         axis=split_axis,
     )
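For reference, a hedged usage sketch of the two helpers touched in this file, assuming the signatures shown in the diff (the shapes are illustrative):

import pytensor.tensor as pt
from pytensor.tensor.shape_ops import join_dims, split_dims

x = pt.tensor("x", shape=(2, 3, 4))
y = join_dims(x, axis=(0, 1))       # collapse consecutive axes 0 and 1 -> shape (6, 4)
z = split_dims(y, (2, 3), axis=0)   # split axis 0 back into (2, 3) -> shape (2, 3, 4)

With this commit, the shape argument of split_dims is cast to int64 up front, and SplitDims.make_node accepts any signed or unsigned integer dtype for it.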
