@@ -656,21 +656,37 @@ def create_jitable_thunk(
656656 thunk_outputs = [storage_map [n ] for n in self .fgraph .outputs ]
657657 fgraph_jit = self .jit_compile (converted_fgraph )
658658
659- def thunk (
660- fgraph_jit = fgraph_jit ,
661- thunk_inputs = thunk_inputs ,
662- thunk_outputs = thunk_outputs ,
663- ):
664- try :
665- outputs = fgraph_jit (* (x [0 ] for x in thunk_inputs ))
666- except Exception :
667- # TODO: Should we add a fake node that combines all outputs,
668- # since the error may come from any of them?
669- raise_with_op (self .fgraph , output_nodes [0 ], thunk )
659+ if thunk_outputs :
670660
671- # zip strict not specified because we are in a hot loop
672- for o_storage , o_val in zip (thunk_outputs , outputs ):
673- o_storage [0 ] = o_val
661+ def thunk (
662+ fgraph_jit = fgraph_jit ,
663+ thunk_inputs = thunk_inputs ,
664+ thunk_outputs = thunk_outputs ,
665+ ):
666+ try :
667+ outputs = fgraph_jit (* (x [0 ] for x in thunk_inputs ))
668+ except Exception :
669+ # TODO: Should we add a fake node that combines all outputs,
670+ # since the error may come from any of them?
671+ raise_with_op (self .fgraph , output_nodes [0 ], thunk )
672+
673+ # zip strict not specified because we are in a hot loop
674+ for o_storage , o_val in zip (thunk_outputs , outputs ):
675+ o_storage [0 ] = o_val
676+
677+ else :
678+ # Edge case - functions without outputs
679+ def thunk (
680+ fgraph_jit = fgraph_jit ,
681+ thunk_inputs = thunk_inputs ,
682+ thunk_outputs = thunk_outputs ,
683+ ):
684+ try :
685+ res = fgraph_jit (* (x [0 ] for x in thunk_inputs ))
686+ except Exception :
687+ raise_with_op (self .fgraph , output_nodes [0 ], thunk )
688+ assert res is None
689+ return thunk_outputs
674690
675691 thunk .inputs = thunk_inputs
676692 thunk .outputs = thunk_outputs
0 commit comments