Coverage for mlprodict/onnxrt/ops_cpu/op_fft.py: 80%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

40 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from numpy.fft import fft 

9from ..shape_object import ShapeObject 

10from ._op import OpRun 

11from ._new_ops import OperatorSchema 

12 

13 

class FFT(OpRun):
    """
    Runtime for the custom operator *FFT*. The computation is
    delegated to :func:`numpy.fft.fft` along ``self.axis``
    (presumably set by @see cl OpRun from ``atts`` -- confirm
    against the base class).
    """

    # Expected attributes and their default values.
    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(
            onnx_node, desc=desc, expected_attributes=FFT.atts, **options)

    @staticmethod
    def _complex_dtype(a):
        """
        Maps the dtype of *a* to the complex dtype of the result.

        @param      a       any object exposing ``dtype``
        @return             ``numpy.complex64`` for float32/complex64,
                            ``numpy.complex128`` for float64/complex128
        @raises TypeError   for any other dtype
        """
        dtype = a.dtype
        if dtype in (numpy.float32, numpy.complex64):
            return numpy.complex64
        if dtype in (numpy.float64, numpy.complex128):
            return numpy.complex128
        raise TypeError(  # pragma: no cover
            "Unexpected input type: %r." % dtype)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema for *op_name*; ``"FFT"`` is the only
        name handled here, anything else raises *RuntimeError*.
        """
        if op_name != "FFT":
            raise RuntimeError(  # pragma: no cover
                "Unable to find a schema for operator '{}'.".format(op_name))
        return FFTSchema()

    def _run(self, a, fft_length=None):  # pylint: disable=W0221
        """
        Computes the FFT of *a* along ``self.axis``.

        @param      a           input array
        @param      fft_length  optional sequence, only its first element
                                is used, as the transform length
        @return                 one-element tuple holding the result cast
                                to the complex dtype matching ``a.dtype``
        """
        if fft_length is None:
            transformed = fft(a, axis=self.axis)
        else:
            transformed = fft(a, fft_length[0], axis=self.axis)
        # The dtype is checked after the transform, so an unsupported
        # input type is only reported once the FFT itself went through.
        out_dtype = self._complex_dtype(a)
        return (transformed.astype(out_dtype), )

    def _infer_shapes(self, a, b=None):  # pylint: disable=W0221,W0237
        """
        Returns a one-element tuple with a @see cl ShapeObject built
        from ``a.shape`` and the complex dtype matching ``a.dtype``.
        The second input *b* is not used.
        """
        out_dtype = self._complex_dtype(a)
        return (ShapeObject(a.shape, dtype=out_dtype), )

    def _infer_types(self, a, b=None):  # pylint: disable=W0221,W0237
        """
        Returns a one-element tuple with the complex dtype matching
        ``a.dtype``. The second input *b* is not used.
        """
        return (self._complex_dtype(a), )

    def to_python(self, inputs):
        """
        Produces python code equivalent to this operator.

        @param      inputs      names of the inputs, the second one,
                                if any, holds the fft length
        @return                 tuple *(import statement, code line)*
        """
        import_line = 'from numpy.fft import fft'
        if len(inputs) == 1:
            code = "return fft({}, axis={})".format(inputs[0], self.axis)
        else:
            code = "return fft({}, {}[0], axis={})".format(
                inputs[0], inputs[1], self.axis)
        return (import_line, code)

66 

67 

class FFTSchema(OperatorSchema):
    """
    Schema of the custom operator implemented by @see cl FFT,
    an operator added in this package.
    """

    def __init__(self):
        super().__init__('FFT')
        # Same dictionary object as the one held by class FFT (no copy).
        self.attributes = FFT.atts

76 self.attributes = FFT.atts