@@ -700,3 +700,34 @@ def atan2(y, x):
700700 op_name_scope = node .name + 'all' ,
701701 shapes = [shape ], dtypes = [onnx_dtype ])
702702 ctx .replace_all_inputs (node .output [0 ], last_node .output [0 ]) # ops=ctx.get_nodes()
703+
704+
705+ @tf_op ("InvertPermutation" )
706+ class InvertPermutationOp :
707+
708+ @classmethod
709+ def version_11 (cls , ctx , node , ** kwargs ):
710+
711+ supported_dtypes = [onnx_pb .TensorProto .INT32 , onnx_pb .TensorProto .INT64 ]
712+ onnx_dtype = ctx .get_dtype (node .input [0 ])
713+ utils .make_sure (onnx_dtype in supported_dtypes , "InvertPermutation only applies on INT32, INT64." )
714+
715+ shape = ctx .get_shape (node .input [0 ])
716+
717+ shape_node = ctx .make_node (
718+ "Shape" , inputs = node .input , name = utils .make_name (node .name + '_shape' ))
719+
720+ neg_node = ctx .make_node (
721+ "Neg" , inputs = node .input , name = utils .make_name (node .name + '_neg' ))
722+
723+ topk_node = ctx .make_node (
724+ "TopK" , inputs = [neg_node .output [0 ], shape_node .output [0 ]],
725+ name = utils .make_name (node .name + '_topk' ), output_count = 2 )
726+
727+ ctx .remove_node (node .name )
728+
729+ last_node = ctx .make_node (
730+ "Identity" , inputs = topk_node .output [1 :], name = utils .make_name (node .name + '_indices' ),
731+ shapes = [shape ], dtypes = [onnx_dtype ])
732+
733+ ctx .replace_all_inputs (node .output [0 ], last_node .output [0 ]) # ops=ctx.get_nodes()
0 commit comments