We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c9b4219 commit 359e69cCopy full SHA for 359e69c
tf2onnx/onnx_opset/tensor.py
@@ -61,7 +61,16 @@ def _wrap_concat_with_cast(ctx, node):
61
class Size:
62
@classmethod
63
def version_1(cls, ctx, node, **kwargs):
64
- ctx.set_dtype(node.output[0], onnx_pb.TensorProto.INT64)
+ output_name = node.output[0]
65
+ dtype = ctx.get_dtype(output_name)
66
+ # TF size can output int32 or int64 but onnx only does int 64
67
+ if dtype != onnx_pb.TensorProto.INT64:
68
+ ctx.set_dtype(output_name, onnx_pb.TensorProto.INT64)
69
+ output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=node.child_name(),
70
+ to=dtype)
71
+ ctx.set_dtype(output_cast.output[0], dtype)
72
+ ctx.copy_shape(output_name, output_cast.output[0])
73
+
74
75
76
@tf_op("Flatten")
0 commit comments