@@ -36,13 +36,6 @@ def make_dft_constant(length, dtype, fft_length):
3636@tf_op ("RFFT" )
3737class RFFTOp :
3838 # support more dtype
39- supported_dtypes = [
40- onnx_pb .TensorProto .FLOAT ,
41- onnx_pb .TensorProto .FLOAT16 ,
42- onnx_pb .TensorProto .DOUBLE ,
43- onnx_pb .TensorProto .COMPLEX64 ,
44- onnx_pb .TensorProto .COMPLEX128 ,
45- ]
4639
4740 @classmethod
4841 def version_1 (cls , ctx , node , ** kwargs ):
@@ -99,6 +92,13 @@ def DFT_real(x, fft_length=None):
9992 res = np.dot(cst, x)
10093 return np.transpose(res, (0, 2, 1))
10194 """
95+ supported_dtypes = [
96+ onnx_pb .TensorProto .FLOAT ,
97+ onnx_pb .TensorProto .FLOAT16 ,
98+ onnx_pb .TensorProto .DOUBLE ,
99+ onnx_pb .TensorProto .COMPLEX64 ,
100+ onnx_pb .TensorProto .COMPLEX128 ,
101+ ]
102102 consumers = ctx .find_output_consumers (node .output [0 ])
103103 consumer_types = set (op .type for op in consumers )
104104 utils .make_sure (
@@ -107,6 +107,7 @@ def DFT_real(x, fft_length=None):
107107 consumer_types )
108108
109109 onnx_dtype = ctx .get_dtype (node .input [0 ])
110+ utils .make_sure (onnx_dtype in supported_dtypes , "Unsupported input type." )
110111 shape = ctx .get_shape (node .input [0 ])
111112 np_dtype = utils .map_onnx_to_numpy_type (onnx_dtype )
112113 shape_n = shape [- 1 ]
@@ -164,13 +165,6 @@ def DFT_real(x, fft_length=None):
164165@tf_op ("ComplexAbs" )
165166class ComplexAbsOp :
166167 # support more dtype
167- supported_dtypes = [
168- onnx_pb .TensorProto .FLOAT ,
169- onnx_pb .TensorProto .FLOAT16 ,
170- onnx_pb .TensorProto .DOUBLE ,
171- onnx_pb .TensorProto .COMPLEX64 ,
172- onnx_pb .TensorProto .COMPLEX128 ,
173- ]
174168
175169 @classmethod
176170 def any_version (cls , opset , ctx , node , ** kwargs ):
@@ -180,7 +174,15 @@ def any_version(cls, opset, ctx, node, **kwargs):
180174 it assumes the first dimension means real part (0)
181175 and imaginary part (1, :, :...).
182176 """
177+ supported_dtypes = [
178+ onnx_pb .TensorProto .FLOAT ,
179+ onnx_pb .TensorProto .FLOAT16 ,
180+ onnx_pb .TensorProto .DOUBLE ,
181+ onnx_pb .TensorProto .COMPLEX64 ,
182+ onnx_pb .TensorProto .COMPLEX128 ,
183+ ]
183184 onnx_dtype = ctx .get_dtype (node .input [0 ])
185+ utils .make_sure (onnx_dtype in supported_dtypes , "Unsupported input type." )
184186 shape = ctx .get_shape (node .input [0 ])
185187 np_dtype = utils .map_onnx_to_numpy_type (onnx_dtype )
186188 utils .make_sure (shape [0 ] == 2 , "ComplexAbs expected the first dimension to be 2 but shape is %r" , shape )
0 commit comments