Skip to content

Commit c96641a

Browse files
committed
Handle functions without outputs in JITLinker
1 parent 1453ba0 commit c96641a

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

pytensor/link/basic.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -656,21 +656,37 @@ def create_jitable_thunk(
656656
thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
657657
fgraph_jit = self.jit_compile(converted_fgraph)
658658

659-
def thunk(
660-
fgraph_jit=fgraph_jit,
661-
thunk_inputs=thunk_inputs,
662-
thunk_outputs=thunk_outputs,
663-
):
664-
try:
665-
outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
666-
except Exception:
667-
# TODO: Should we add a fake node that combines all outputs,
668-
# since the error may come from any of them?
669-
raise_with_op(self.fgraph, output_nodes[0], thunk)
659+
if thunk_outputs:
670660

671-
# zip strict not specified because we are in a hot loop
672-
for o_storage, o_val in zip(thunk_outputs, outputs):
673-
o_storage[0] = o_val
661+
def thunk(
662+
fgraph_jit=fgraph_jit,
663+
thunk_inputs=thunk_inputs,
664+
thunk_outputs=thunk_outputs,
665+
):
666+
try:
667+
outputs = fgraph_jit(*(x[0] for x in thunk_inputs))
668+
except Exception:
669+
# TODO: Should we add a fake node that combines all outputs,
670+
# since the error may come from any of them?
671+
raise_with_op(self.fgraph, output_nodes[0], thunk)
672+
673+
# zip strict not specified because we are in a hot loop
674+
for o_storage, o_val in zip(thunk_outputs, outputs):
675+
o_storage[0] = o_val
676+
677+
else:
678+
# Edge case - functions without outputs
679+
def thunk(
680+
fgraph_jit=fgraph_jit,
681+
thunk_inputs=thunk_inputs,
682+
thunk_outputs=thunk_outputs,
683+
):
684+
try:
685+
res = fgraph_jit(*(x[0] for x in thunk_inputs))
686+
except Exception:
687+
raise_with_op(self.fgraph, output_nodes[0], thunk)
688+
assert res is None
689+
return thunk_outputs
674690

675691
thunk.inputs = thunk_inputs
676692
thunk.outputs = thunk_outputs

0 commit comments

Comments
 (0)