Coverage for mlprodict/onnxrt/ops_cpu/op_rfft.py: 79%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

39 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from numpy.fft import rfft 

9from ..shape_object import ShapeObject 

10from ._op import OpRun 

11from ._new_ops import OperatorSchema 

12 

13 

class RFFT(OpRun):
    """
    Runtime implementation of the custom *RFFT* operator.

    It computes the one-dimensional discrete Fourier transform of a
    real input with :func:`numpy.fft.rfft`, along ``self.axis``.
    The attribute *axis* is declared with default -1 in ``atts``.
    Only float32 and float64 inputs are supported.
    They produce complex64 and complex128 outputs respectively.
    """

    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=RFFT.atts,
                       **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of this custom operator.

        :param op_name: operator name, only ``"RFFT"`` is known
        :return: an instance of :class:`RFFTSchema`
        :raises RuntimeError: for any other operator name
        """
        if op_name == "RFFT":
            return RFFTSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    @staticmethod
    def _complex_type(dtype):
        """
        Maps a supported real dtype to the complex dtype of the output.

        This is shared by `_run`, `_infer_shapes` and `_infer_types`,
        so all three accept exactly the same inputs and report the
        same error.

        :param dtype: element type of the input
        :return: ``numpy.complex64`` for float32,
            ``numpy.complex128`` for float64
        :raises TypeError: for any other element type
        """
        if dtype == numpy.float32:
            return numpy.complex64
        if dtype == numpy.float64:
            return numpy.complex128
        raise TypeError(  # pragma: no cover
            "Unexpected input type: %r." % dtype)

    def _run(self, a, fft_length=None):  # pylint: disable=W0221
        """
        Computes the real FFT of *a* along ``self.axis``.

        :param a: real input array (float32 or float64)
        :param fft_length: optional array whose first element is passed
            to :func:`numpy.fft.rfft` as the transform length *n*
        :return: one-element tuple holding the complex result
        :raises TypeError: if *a* is neither float32 nor float64
        """
        # Validate the dtype before the transform so that unsupported
        # inputs fail without paying for the FFT.
        ctype = self._complex_type(a.dtype)
        if fft_length is not None:
            fft_length = fft_length[0]
        y = rfft(a, fft_length, axis=self.axis)
        # copy=False: no extra copy when rfft already produced ctype.
        return (y.astype(ctype, copy=False), )

    def _infer_shapes(self, a, b=None):  # pylint: disable=W0221,W0237
        """
        Infers the output shape object from the input shape object *a*.

        NOTE(review): this returns ``a.shape`` unchanged.
        However, :func:`numpy.fft.rfft` yields ``n // 2 + 1`` values
        along the transformed axis, where *n* is *fft_length* or the
        input length. Confirm whether callers rely on the current
        behaviour before changing it.

        :raises TypeError: if ``a.dtype`` is neither float32 nor float64
        """
        return (ShapeObject(a.shape, dtype=self._complex_type(a.dtype)), )

    def _infer_types(self, a, b=None):  # pylint: disable=W0221,W0237
        """
        Infers the output element type from the input type of *a*.

        :return: one-element tuple holding the complex output type
        :raises TypeError: if ``a.dtype`` is neither float32 nor float64
        """
        return (self._complex_type(a.dtype), )

    def to_python(self, inputs):
        """
        Returns ``(import statement, code)`` reproducing this operator
        in plain python.

        NOTE(review): the generated code does not cast the result.
        `_run` returns complex64 for float32 input, whereas
        ``numpy.fft.rfft`` in numpy < 2.0 always returns complex128.
        Verify whether the generated code must match exactly.

        :param inputs: names of the inputs; a second name, if present,
            is the *fft_length* tensor
        """
        import_line = 'from numpy.fft import rfft'
        if len(inputs) == 1:
            code = "return rfft({}, axis={})".format(inputs[0], self.axis)
        else:
            code = "return rfft({}, {}[0], axis={})".format(
                inputs[0], inputs[1], self.axis)
        return (import_line, code)

64 

65 

class RFFTSchema(OperatorSchema):
    """
    Schema of the custom operator *RFFT* added by this package.

    It reuses the attribute declaration of @see cl RFFT.
    """

    def __init__(self):
        super().__init__('RFFT')
        # Shares the attribute dictionary declared by the runtime class.
        self.attributes = RFFT.atts