File tree Expand file tree Collapse file tree 2 files changed +20
-2
lines changed
Expand file tree Collapse file tree 2 files changed +20
-2
lines changed Original file line number Diff line number Diff line change @@ -3659,10 +3659,22 @@ def dft_slow(x, M):
36593659 assert_almost_equal (fft [1 , :, :], np .imag (fft_npy ))
36603660
36613661 x_val = make_xval ([3 , 4 ]).astype (np .float32 )
3662- def func (x ):
3662+ def func1 (x ):
36633663 op_ = tf .signal .rfft (x )
36643664 return tf .abs (op_ , name = _TFOUTPUT )
3665- self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
3665+ self ._run_test_case (func1 , [_OUTPUT ], {_INPUT : x_val })
3666+
3667+ def func2 (x ):
3668+ op_ = tf .signal .rfft (x )
3669+ return tf .cos (op_ , name = _TFOUTPUT )
3670+ with self .assertRaises (ValueError ):
3671+ self ._run_test_case (func2 , [_OUTPUT ], {_INPUT : x_val })
3672+
3673+ def func3 (x ):
3674+ op_ = tf .signal .rfft (x )
3675+ return tf .identity (op_ , name = _TFOUTPUT )
3676+ with self .assertRaises (ValueError ):
3677+ self ._run_test_case (func3 , [_OUTPUT ], {_INPUT : x_val })
36663678
36673679
36683680if __name__ == '__main__' :
Original file line number Diff line number Diff line change @@ -99,6 +99,12 @@ def DFT_real(x, fft_length=None):
9999 res = np.dot(cst, x)
100100 return np.transpose(res, (0, 2, 1))
101101 """
102+ consumers = ctx .find_output_consumers (node .output [0 ])
103+ consumer_types = set (op .type for op in consumers )
104+ utils .make_sure (
105+ consumer_types == {'ComplexAbs' },
106+ "Current implementation of RFFT only allows ComplexAbs as consumer not %r" ,
107+ consumer_types )
102108
103109 onnx_dtype = ctx .get_dtype (node .input [0 ])
104110 shape = ctx .get_shape (node .input [0 ])
You can’t perform that action at this time.
0 commit comments