Skip to content

Commit 82042b9

Browse files
Add rewrite to test
1 parent 0b600ac commit 82042b9

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tests/tensor/test_reshape_ops.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytensor
55
from pytensor import config, function
66
from pytensor import tensor as pt
7-
from pytensor.graph import vectorize_graph
7+
from pytensor.graph import rewrite_graph, vectorize_graph
88
from pytensor.tensor.shape_ops import (
99
_analyze_axes_list,
1010
join_dims,
@@ -115,13 +115,15 @@ def test_make_replacements_with_pack_unpack():
115115
new_outputs = unpack(new_input, axes=None, packed_shapes=packed_shapes)
116116

117117
loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
118-
fn = pytensor.function([new_input, x, y, z], loss, mode="FAST_COMPILE")
118+
rewrite_graph(loss, include=("ShapeOpt", "specialize"))
119+
120+
fn = pytensor.function([new_input], loss, mode="FAST_COMPILE")
119121

120122
input_vals = [
121123
rng.normal(size=(var.type.shape)).astype(config.floatX) for var in [x, y, z]
122124
]
123125
flat_inputs = np.concatenate([input.ravel() for input in input_vals], axis=0)
124-
output_val = fn(flat_inputs, *input_vals)
126+
output_val = fn(flat_inputs)
125127

126128
assert np.allclose(output_val, sum([input.sum() for input in input_vals]) ** 2)
127129

0 commit comments

Comments
 (0)