@@ -11,13 +11,11 @@
 from pytensor.graph.replace import _vectorize_node
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor.basic import (
-    atleast_1d,
     expand_dims,
-    get_scalar_constant_value,
+    infer_static_shape,
     join,
     split,
 )
-from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import prod
 from pytensor.tensor.shape import ShapeValueType
 from pytensor.tensor.type import tensor
@@ -31,20 +29,12 @@ class JoinDims(Op):
     )
     view_map = {0: [0]}

-    def __init__(self, input_ndims: int, start_axis: int, n_axes: int):
+    def __init__(self, start_axis: int, n_axes: int):
         if start_axis < 0:
             raise ValueError("JoinDims start_axis must be non-negative")

         self.start_axis = start_axis
         self.n_axes = n_axes
-        self.input_ndims = input_ndims
-
-        output_ndims = 1 if not start_axis else min(1, input_ndims - n_axes)
-
-        input_signature = ",".join(f"i{i}" for i in range(input_ndims))
-        output_signature = ",".join(f"o{i}" for i in range(output_ndims))
-
-        self.gufunc_signature = f"({input_signature})->({output_signature})"

     @property
     def axis_range(self):
@@ -59,11 +49,6 @@ def output_shapes(self, input_shapes, joined_shape):

     def make_node(self, x: Variable) -> Apply:  # type: ignore[override]
         static_shapes = x.type.shape
-        if x.type.ndim != self.input_ndims:
-            raise ValueError(
-                f"Input ndim {x.type.ndim} is not equal to expected ndim {self.input_ndims}"
-            )
-
         axis_range = self.axis_range

         joined_shape = (
@@ -88,13 +73,24 @@ def perform(self, node, inputs, outputs):
         (x,) = inputs
         (out,) = outputs

-        output_shape = [
+        output_shape = (
             *x.shape[: self.start_axis],
             -1,
             *x.shape[self.start_axis + self.n_axes :],
-        ]
+        )
+
+        out[0] = x.reshape(output_shape)
+

-        out[0] = x.reshape(tuple(output_shape))
+@_vectorize_node.register(JoinDims)
+def _vectorize_joindims(op, node, x):
+    [old_x] = node.inputs
+
+    batched_ndims = x.type.ndim - old_x.type.ndim
+    start_axis = op.start_axis
+    n_axes = op.n_axes
+
+    return JoinDims(start_axis + batched_ndims, n_axes).make_node(x)


 def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
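
The _vectorize_joindims dispatch above is what makes JoinDims batchable: extra leading dimensions on the new input simply shift start_axis. A minimal sketch of the expected behavior, assuming join_dims is importable from the module this diff edits (the import path is not shown in the diff and is assumed here):

    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_graph
    # from <this module> import join_dims  # path assumed, not part of this diff

    x = pt.tensor("x", shape=(3, 4, 5))
    y = join_dims(x, axis=(0, 1))                # joins axes 0-1 -> shape (12, 5)

    batched_x = pt.tensor("bx", shape=(7, 3, 4, 5))
    batched_y = vectorize_graph(y, replace={x: batched_x})
    assert batched_y.type.shape == (7, 12, 5)    # start_axis shifted from 0 to 1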
@@ -129,16 +125,12 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
     elif not isinstance(axis, list | tuple):
         raise TypeError("axis must be an int, a list/tuple of ints, or None")

-    if not axis:
-        # The user passed an empty list/tuple, so we return the input as is
-        return x
-
     axis = normalize_axis_tuple(axis, x.ndim)

-    if any(i < 0 for i in axis):
-        raise ValueError("join_dims axis must be non-negative")
+    if len(axis) <= 1:
+        return x

-    if len(axis) > 1 and np.diff(axis).max() > 1:
+    if np.diff(axis).max() > 1:
         raise ValueError(
             f"join_dims axis must be consecutive, got normalized axis: {axis}"
         )
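
For reference, the axis handling above should behave as follows (a hypothetical REPL session; negative axes are normalized before the consecutiveness check runs):

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3, 4, 5))
    join_dims(x, axis=(1, 2)).type.shape    # (2, 12, 5)
    join_dims(x, axis=(-2, -1)).type.shape  # (2, 3, 20) -- negatives normalized
    join_dims(x, axis=1)                    # single axis: input returned unchanged
    join_dims(x, axis=(0, 2))               # ValueError: axes must be consecutive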
@@ -148,7 +140,7 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:

     return type_cast(
         TensorVariable,
-        JoinDims(input_ndims=x.ndim, start_axis=start_axis, n_axes=n_axes)(x),
+        JoinDims(start_axis=start_axis, n_axes=n_axes)(x),
     )


@@ -162,17 +154,11 @@ def __init__(self, axis: int):
         self.axis = axis

     def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
-        if shape.type.dtype not in ("int8", "int16", "int32", "int64"):
+        if shape.type.numpy_dtype.kind not in "iu":
             raise TypeError("shape must be an integer tensor")

-        def _get_constant_shape(x):
-            try:
-                return get_scalar_constant_value(x).item()
-            except NotScalarConstantError:
-                return x
-
         axis = self.axis
-        constant_shape = [_get_constant_shape(s) for s in shape]  # type: ignore[attr-defined]
+        _, constant_shape = infer_static_shape(shape)

         output_shapes = [
             *x.type.shape[:axis],
@@ -181,7 +167,7 @@ def _get_constant_shape(x):
         ]

         output = tensor(
-            shape=tuple([x if isinstance(x, int) else None for x in output_shapes]),
+            shape=tuple(x if isinstance(x, int) else None for x in output_shapes),
             dtype=x.type.dtype,
         )
         return Apply(self, [x, shape], [output])
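
Since infer_static_shape folds constant entries, a SplitDims node built from a literal shape should carry a fully static output type, while symbolic entries fall back to None. A sketch, with the split_dims import path assumed as before:

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(6, 5))
    assert split_dims(x, (2, 3), axis=0).type.shape == (2, 3, 5)

    n = pt.iscalar("n")
    assert split_dims(x, (n, 3), axis=0).type.shape == (None, 3, 5)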
@@ -199,11 +185,7 @@ def perform(self, node, inputs, outputs):
         (x, shape) = inputs
         (out,) = outputs

-        output_shape = [
-            *x.shape[: self.axis],
-            *shape,
-            *x.shape[self.axis + 1 :],
-        ]
+        output_shape = (*x.shape[: self.axis], *shape, *x.shape[self.axis + 1 :])

         out[0] = x.reshape(output_shape)

@@ -219,7 +201,7 @@ def _vectorize_splitdims(op, node, x, shape):
         return vectorize_node_fallback(op, node, x, shape)

     axis = op.axis
-    return split_dims(x, shape, axis=axis + batched_ndims).owner
+    return SplitDims(axis=axis + batched_ndims).make_node(x, shape)


 def split_dims(
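
Constructing the Apply directly avoids re-entering the user-facing split_dims wrapper (and its scalar/squeeze special cases) during vectorization. The expected effect, sketched with assumed imports:

    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_graph

    x = pt.tensor("x", shape=(6, 5))
    y = split_dims(x, (2, 3), axis=0)        # shape (2, 3, 5)

    bx = pt.tensor("bx", shape=(7, 6, 5))
    by = vectorize_graph(y, replace={x: bx})
    assert by.type.shape == (7, 2, 3, 5)     # axis shifted by the batch ndims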
@@ -272,7 +254,7 @@ def split_dims(
         return type_cast(TensorVariable, x.squeeze(axis=axis))

     [axis] = normalize_axis_tuple(axis, x.ndim)  # type: ignore[misc]
-    shape = as_tensor_variable(shape)  # type: ignore[arg-type]
+    shape = as_tensor_variable(shape, dtype="int64")  # type: ignore[arg-type]

     split_op = SplitDims(axis=axis)
     return type_cast(TensorVariable, split_op(x, shape))
@@ -468,13 +450,6 @@ def pack(
     reshaped_tensors: list[TensorVariable] = []
     packed_shapes: list[ShapeValueType] = []

-    if all([n_before == 0, n_after == 0, min_axes == 0]):
-        # Special case -- we're raveling everything
-        packed_shapes = [t.shape for t in tensor_list]
-        reshaped_tensors = [atleast_1d(join_dims(t, None)) for t in tensor_list]
-
-        return join(0, *reshaped_tensors), packed_shapes
-
     for i, input_tensor in enumerate(tensor_list):
         n_dim = input_tensor.ndim
@@ -488,24 +463,16 @@ def pack(

         if n_dim == min_axes:
             # If an input has the minimum number of axes, pack implicitly inserts a new axis based on the pattern
-            # implied by the axes. If n_before == 0, the reshape would be (-1, ...), so we need to expand at axis 0.
-            # If n_after == 0, the reshape would be (..., -1), so we need to expand at axis -1. If both are equal,
-            # the reshape will occur in the center of the tensor.
-            if n_before == 0:
-                input_tensor = expand_dims(input_tensor, axis=0)
-            elif n_after == 0:
-                input_tensor = expand_dims(input_tensor, axis=-1)
-            elif n_before == n_after:
-                input_tensor = expand_dims(input_tensor, axis=n_before)
-
+            # implied by the axes.
+            input_tensor = expand_dims(input_tensor, axis=n_before)
             reshaped_tensors.append(input_tensor)
             continue

         # The reshape we want is (shape[:before], -1, shape[n_after_packed:]). join_dims does (shape[:min(axes)], -1,
         # shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the
-        # rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the
-        # corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing).
-        join_axes = {n_before, n_after_packed - 1}
+        # rules on the axes input, we will always have n_before <= n_after_packed - 1. A range keeps the axes
+        # consecutive and covers the corner case n_before == n_after_packed - 1 (one axis to ravel: join_dims no-ops).
+        join_axes = range(n_before, n_after_packed)
         joined = join_dims(input_tensor, tuple(join_axes))
         reshaped_tensors.append(joined)

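
A worked instance of the arithmetic above: for an input of shape (2, 3, 4, 5) with n_before = 1 and the trailing kept dimensions starting at n_after_packed = 3, join_axes = range(1, 3), so axes 1-2 are raveled:

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3, 4, 5))
    joined = join_dims(x, tuple(range(1, 3)))   # axes (1, 2)
    assert joined.type.shape == (2, 12, 5)      # (shape[:1], -1, shape[3:])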
@@ -560,7 +527,7 @@ def unpack(

     split_inputs = split(
         packed_input,
-        splits_size=[prod(shape).astype(int) for shape in packed_shapes],
+        splits_size=[prod(shape, dtype=int) for shape in packed_shapes],
         n_splits=len(packed_shapes),
         axis=split_axis,
     )
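
The dtype=int accumulator keeps the split sizes integral without a separate cast node, and an empty packed shape (a scalar entry) correctly yields a size of 1. A quick check:

    import numpy as np
    from pytensor.tensor.math import prod

    assert prod(np.array([2, 3]), dtype=int).eval() == 6
    assert prod(np.array([], dtype="int64"), dtype=int).eval() == 1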