Commit e5b1d3e

Optimize: Guard against unsupported input types

1 parent 1d13f8c commit e5b1d3e
2 files changed: +183 -20 lines changed

pytensor/tensor/optimize.py

Lines changed: 76 additions & 17 deletions
@@ -7,12 +7,13 @@

 import pytensor.scalar as ps
 from pytensor.compile.function import function
-from pytensor.gradient import grad, jacobian
+from pytensor.gradient import grad, grad_not_implemented, jacobian
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.graph.replace import graph_replace
 from pytensor.graph.traversal import ancestors, truncated_graph_inputs
+from pytensor.scalar import ScalarType, ScalarVariable
 from pytensor.tensor.basic import (
     atleast_2d,
     concatenate,
@@ -22,6 +23,7 @@
 )
 from pytensor.tensor.math import dot
 from pytensor.tensor.slinalg import solve
+from pytensor.tensor.type import DenseTensorType
 from pytensor.tensor.variable import TensorVariable, Variable


@@ -140,23 +142,19 @@ def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):


 def _get_parameter_grads_from_vector(
-    grad_wrt_args_vector: Variable,
-    x_star: Variable,
-    args: Sequence[Variable],
+    grad_wrt_args_vector: TensorVariable,
+    x_star: TensorVariable,
+    args: Sequence[TensorVariable | ScalarVariable],
     output_grad: Variable,
 ):
     """
     Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
     returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
     """
-    grad_wrt_args_vector = cast(TensorVariable, grad_wrt_args_vector)
-    x_star = cast(TensorVariable, x_star)
-
     cursor = 0
     grad_wrt_args = []

     for arg in args:
-        arg = cast(TensorVariable, arg)
         arg_shape = arg.shape
         arg_size = arg_shape.prod()
         arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
@@ -375,14 +373,18 @@ def __init__(
         method: str = "brent",
         optimizer_kwargs: dict | None = None,
     ):
-        if not cast(TensorVariable, x).ndim == 0:
+        if not (isinstance(x, TensorVariable) and x.ndim == 0):
             raise ValueError(
                 "The variable `x` must be a scalar (0-dimensional) tensor for minimize_scalar."
             )
-        if not cast(TensorVariable, objective).ndim == 0:
+        if not (isinstance(objective, TensorVariable) and objective.ndim == 0):
             raise ValueError(
                 "The objective function must be a scalar (0-dimensional) tensor for minimize_scalar."
             )
+        if x not in ancestors([objective]):
+            raise ValueError(
+                "The variable `x` must be an input to the computational graph of the objective function."
+            )
         self.fgraph = FunctionGraph([x, *args], [objective])

         self.method = method
@@ -416,7 +418,19 @@ def perform(self, node, inputs, outputs):
         outputs[1][0] = np.bool_(res.success)

     def L_op(self, inputs, outputs, output_grads):
+        # TODO: Handle disconnected inputs, instead of zeroing them out or failing for unsupported types
         x, *args = inputs
+        if non_supported_types := tuple(
+            inp.type
+            for inp in inputs
+            if not isinstance(inp.type, DenseTensorType | ScalarType)
+        ):
+            # TODO: Support SparseTensorTypes
+            # TODO: Remaining types are likely just disconnected anyway
+            msg = f"Minimize gradient not implemented due to inputs of type {non_supported_types}"
+            return [
+                grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
+            ]
         x_star, _ = outputs
         output_grad, _ = output_grads

@@ -468,7 +482,6 @@ def minimize_scalar(
         Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
         value, based on the requested convergence criteria.
     """
-
    args = _find_optimization_parameters(objective, x)

    minimize_scalar_op = MinimizeScalarOp(
@@ -499,7 +512,11 @@ def __init__(
         use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
-        if not cast(TensorVariable, objective).ndim == 0:
+        if not (isinstance(x, TensorVariable) and x.ndim in (0, 1)):
+            raise ValueError(
+                "The variable `x` must be a scalar or vector (0-or-1-dimensional) tensor for minimize."
+            )
+        if not (isinstance(objective, TensorVariable) and objective.ndim == 0):
             raise ValueError(
                 "The objective function must be a scalar (0-dimensional) tensor for minimize."
             )
@@ -570,7 +587,19 @@ def perform(self, node, inputs, outputs):
         outputs[1][0] = np.bool_(res.success)

     def L_op(self, inputs, outputs, output_grads):
+        # TODO: Handle disconnected inputs, instead of zeroing them out or failing for unsupported types
         x, *args = inputs
+        if non_supported_types := tuple(
+            inp.type
+            for inp in inputs
+            if not isinstance(inp.type, DenseTensorType | ScalarType)
+        ):
+            # TODO: Support SparseTensorTypes
+            # TODO: Remaining types are likely just disconnected anyway
+            msg = f"MinimizeOp gradient not implemented due to inputs of type {non_supported_types}"
+            return [
+                grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
+            ]
         x_star, _success = outputs
         output_grad, _ = output_grads

@@ -672,13 +701,15 @@ def __init__(
         hess: bool = False,
         optimizer_kwargs=None,
     ):
-        if not equation.ndim == 0:
+        if not (isinstance(variables, TensorVariable) and variables.ndim == 0):
+            raise ValueError(
+                "The variable `x` must be a scalar (0-dimensional) tensor for root_scalar."
+            )
+        if not (isinstance(equation, TensorVariable) and equation.ndim == 0):
             raise ValueError(
                 "The equation must be a scalar (0-dimensional) tensor for root_scalar."
             )
-        if not isinstance(variables, Variable) or variables not in ancestors(
-            [equation]
-        ):
+        if variables not in ancestors([equation]):
             raise ValueError(
                 "The variable `variables` must be an input to the computational graph of the equation."
             )
@@ -741,7 +772,19 @@ def perform(self, node, inputs, outputs):
         outputs[1][0] = np.bool_(res.converged)

     def L_op(self, inputs, outputs, output_grads):
+        # TODO: Handle disconnected inputs, instead of zeroing them out or failing for unsupported types
         x, *args = inputs
+        if non_supported_types := tuple(
+            inp.type
+            for inp in inputs
+            if not isinstance(inp.type, DenseTensorType | ScalarType)
+        ):
+            # TODO: Support SparseTensorTypes
+            # TODO: Remaining types are likely just disconnected anyway
+            msg = f"RootScalarOp gradient not implemented due to inputs of type {non_supported_types}"
+            return [
+                grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
+            ]
         x_star, _ = outputs
         output_grad, _ = output_grads

@@ -833,7 +876,11 @@ def __init__(
         optimizer_kwargs: dict | None = None,
         use_vectorized_jac: bool = False,
     ):
-        if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
+        if not isinstance(variables, TensorVariable):
+            raise ValueError("The variable `variables` must be a tensor for root.")
+        if not isinstance(equations, TensorVariable):
+            raise ValueError("The equations must be a tensor for root.")
+        if variables.ndim != equations.ndim:
             raise ValueError(
                 "The variable `variables` must have the same number of dimensions as the equations."
             )
@@ -922,7 +969,19 @@ def L_op(
         outputs: Sequence[Variable],
         output_grads: Sequence[Variable],
     ) -> list[Variable]:
+        # TODO: Handle disconnected inputs, instead of zeroing them out or failing for unsupported types
         x, *args = inputs
+        if non_supported_types := tuple(
+            inp.type
+            for inp in inputs
+            if not isinstance(inp.type, DenseTensorType | ScalarType)
+        ):
+            # TODO: Support SparseTensorTypes
+            # TODO: Remaining types are likely just disconnected anyway
+            msg = f"RootOp gradient not implemented due to inputs of type {non_supported_types}"
+            return [
+                grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
+            ]
         x_star, _ = outputs
         output_grad, _ = output_grads

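The type guard added to the four L_op methods above is repeated verbatim; only the Op name in the error message differs. A possible follow-up, not part of this commit, would be to factor it into a shared module-level helper. A minimal sketch, assuming the imports this diff adds to optimize.py (grad_not_implemented, DenseTensorType, ScalarType) and using the hypothetical name _grads_for_unsupported_inputs:

def _grads_for_unsupported_inputs(op, inputs):
    # Hypothetical helper (not in this commit): return grad_not_implemented
    # placeholders when any input type is neither a dense tensor nor a scalar,
    # mirroring the guard each L_op above repeats.
    non_supported_types = tuple(
        inp.type
        for inp in inputs
        if not isinstance(inp.type, DenseTensorType | ScalarType)
    )
    if not non_supported_types:
        return None
    msg = f"{type(op).__name__} gradient not implemented due to inputs of type {non_supported_types}"
    return [grad_not_implemented(op, i, inp, msg) for i, inp in enumerate(inputs)]

Each L_op could then start with `if (bad_grads := _grads_for_unsupported_inputs(self, inputs)) is not None: return bad_grads`, keeping the TODOs about sparse and disconnected inputs in a single place.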
tests/tensor/test_optimize.py

Lines changed: 107 additions & 3 deletions
@@ -3,9 +3,11 @@

 import pytensor
 import pytensor.tensor as pt
-from pytensor import config, function
-from pytensor.graph import Apply, Op
-from pytensor.tensor import scalar
+from pytensor import Variable, config, function
+from pytensor.gradient import NullTypeGradError, disconnected_type
+from pytensor.graph import Apply, Op, Type
+from pytensor.scalar import float64
+from pytensor.tensor import alloc, scalar, scalar_from_tensor, tensor_from_scalar
 from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar
 from tests import unittest_tools as utt

@@ -248,3 +250,105 @@ def L_op(self, inputs, outputs, out_grads):
     np.testing.assert_allclose(
         opt_x_res, 0, atol=1e-15 if floatX == "float64" else 1e-6
     )
+
+
+@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
+def test_minimize_grad_scalar_arg(optimize_op):
+    # Regression test for https://github.com/pymc-devs/pytensor/pull/1744
+    x = scalar("x")
+    theta = float64("theta")
+    obj = tensor_from_scalar((scalar_from_tensor(x) + theta) ** 2)
+    x0, _ = optimize_op(obj, x)
+
+    # Confirm theta is a direct input to the node
+    assert x0.owner.inputs[1] is theta
+
+    grad_wrt_theta = pt.grad(x0, theta)
+    np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: np.e}), -1)
+
+
+@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
+def test_minimize_grad_disconnected_numerical_inp(optimize_op):
+    x = scalar("x", dtype="float64")
+    theta = scalar("theta", dtype="int64")
+    obj = alloc(x**2, theta).sum()  # repeat theta times and sum
+    x0, _ = optimize_op(obj, x)
+
+    # Confirm theta is a direct input to the node
+    assert x0.owner.inputs[1] is theta
+
+    # This should technically raise, but does not right now
+    grad_wrt_theta = pt.grad(x0, theta, disconnected_inputs="raise")
+    np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: 5}), 0)
+
+    # This should work even if the previous one raised
+    grad_wrt_theta = pt.grad(x0, theta, disconnected_inputs="ignore")
+    np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: 5}), 0)
+
+
+@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
+def test_minimize_grad_disconnected_non_numerical_inp(optimize_op):
+    class StrType(Type):
+        def filter(self, x, **kwargs):
+            if isinstance(x, str):
+                return x
+            raise TypeError
+
+    class SmileOrFrown(Op):
+        def make_node(self, x, str_emoji):
+            return Apply(self, [x, str_emoji], [x.type()])
+
+        def perform(self, node, inputs, output_storage):
+            [x, str_emoji] = inputs
+            match str_emoji:
+                case ":)":
+                    out = np.array(x)
+                case ":(":
+                    out = np.array(-x)
+                case _:
+                    raise ValueError("str_emoji must be a smile or a frown")
+            output_storage[0][0] = out
+
+        def connection_pattern(self, node):
+            # Gradient connected only to first input
+            return [[True], [False]]
+
+        def L_op(self, inputs, outputs, output_gradients):
+            [_x, str_emoji] = inputs
+            [g] = output_gradients
+            return [
+                self(g, str_emoji),
+                disconnected_type(),
+            ]
+
+    # We could try to use real types like NoneTypeT or SliceType, but this is more robust to future API changes
+    str_type = StrType()
+    smile_or_frown = SmileOrFrown()
+
+    x = scalar("x", dtype="float64")
+    num_theta = pt.scalar("num_theta", dtype="float64")
+    str_theta = Variable(str_type, None, None, name="str_theta")
+    obj = (smile_or_frown(x, str_theta) + num_theta) ** 2
+    x_star, _ = optimize_op(obj, x)
+
+    # Confirm thetas are direct inputs to the node
+    assert set(x_star.owner.inputs[1:]) == {num_theta, str_theta}
+
+    # Confirm forward pass works, no point in worrying about gradient otherwise
+    np.testing.assert_allclose(
+        x_star.eval({x: np.pi, num_theta: np.e, str_theta: ":)"}),
+        -np.e,
+    )
+    np.testing.assert_allclose(
+        x_star.eval({x: np.pi, num_theta: np.e, str_theta: ":("}),
+        np.e,
+    )
+
+    with pytest.raises(NullTypeGradError):
+        pt.grad(x_star, str_theta, disconnected_inputs="raise")
+
+    # This could be supported, but it is not right now.
+    with pytest.raises(NullTypeGradError):
+        _grad_wrt_num_theta = pt.grad(x_star, num_theta, disconnected_inputs="raise")
+    # np.testing.assert_allclose(grad_wrt_num_theta.eval({x: np.pi, num_theta: np.e, str_theta: ":)"}), -1)
+    # np.testing.assert_allclose(grad_wrt_num_theta.eval({x: np.pi, num_theta: np.e, str_theta: ":("}), 1)

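Beyond these regression tests, the ordinary path is untouched by the new guard: when every input to the op is a dense tensor or scalar variable, L_op falls through the guard and computes the gradient as usual, so pt.grad works through the optimizer output. A minimal usage sketch, not part of the commit, assuming only the public minimize_scalar API exercised in the tests above:

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize_scalar

x = pt.scalar("x")
theta = pt.scalar("theta")

# x* = argmin_x (x - theta)**2 = theta, so dx*/dtheta = 1
x_star, success = minimize_scalar((x - theta) ** 2, x)
grad_wrt_theta = pt.grad(x_star, theta)
np.testing.assert_allclose(grad_wrt_theta.eval({x: 0.0, theta: 3.0}), 1.0)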