@@ -704,15 +704,14 @@ def atan2(y, x):
704704
705705@tf_op ("InvertPermutation" )
706706class InvertPermutationOp :
707- supported_dtypes = [
708- onnx_pb .TensorProto .INT32 ,
709- onnx_pb .TensorProto .INT64 ,
710- ]
711707
712708 @classmethod
713709 def version_11 (cls , ctx , node , ** kwargs ):
714710
711+ supported_dtypes = [onnx_pb .TensorProto .INT32 , onnx_pb .TensorProto .INT64 ]
715712 onnx_dtype = ctx .get_dtype (node .input [0 ])
713+ utils .make_sure (onnx_dtype in supported_dtypes , "InvertPermutation only applies on INT32, INT64." )
714+
716715 shape = ctx .get_shape (node .input [0 ])
717716
718717 shape_node = ctx .make_node (
@@ -721,12 +720,9 @@ def version_11(cls, ctx, node, **kwargs):
721720 neg_node = ctx .make_node (
722721 "Neg" , inputs = node .input , name = utils .make_name (node .name + '_neg' ))
723722
724- topk_unused = utils .make_name (node .name + '_topk' )
725- topk_indices = utils .make_name (node .name + '_indices' )
726- outputs = [topk_unused , utils .port_name (topk_indices , 1 )]
727723 topk_node = ctx .make_node (
728724 "TopK" , inputs = [neg_node .output [0 ], shape_node .output [0 ]],
729- name = utils .make_name (node .name + '_topk' ), outputs = outputs )
725+ name = utils .make_name (node .name + '_topk' ), output_count = 2 )
730726
731727 ctx .remove_node (node .name )
732728
0 commit comments