Coverage for mlprodict/onnxrt/ops_cpu/op_fft2d.py: 80%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

45 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from numpy.fft import fft2 

9from ..shape_object import ShapeObject 

10from ._op import OpRun 

11from ._new_ops import OperatorSchema 

12 

13 

class FFT2D(OpRun):
    """
    Runtime for the custom operator *FFT2D*: a two-dimensional discrete
    Fourier transform computed with :func:`numpy.fft.fft2` over the axes
    stored in attribute *axes* (``[-2, -1]`` unless the node overrides it).
    Inputs of type float32/complex64 produce complex64 outputs,
    inputs of type float64/complex128 produce complex128 outputs.
    """

    atts = {'axes': [-2, -1]}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=FFT2D.atts,
                       **options)
        if self.axes is not None:
            self.axes = tuple(self.axes)
            # The operator is a 2-D transform: exactly two axes are required.
            if len(self.axes) != 2:
                # The tuple is wrapped in a 1-tuple on purpose: formatting
                # ``"%r" % self.axes`` would consume each element as a
                # separate argument and raise TypeError instead.
                raise ValueError(  # pragma: no cover
                    "axes must be a set of 2 integers not %r." % (self.axes, ))

    @staticmethod
    def _complex_dtype(dtype):
        """
        Returns the complex type produced for an input of type *dtype*:
        complex64 for float32/complex64, complex128 for float64/complex128.

        :param dtype: element type of the input
        :return: ``numpy.complex64`` or ``numpy.complex128``
        :raises TypeError: for any other input type
        """
        if dtype in (numpy.float32, numpy.complex64):
            return numpy.complex64
        if dtype in (numpy.float64, numpy.complex128):
            return numpy.complex128
        raise TypeError(  # pragma: no cover
            "Unexpected input type: %r." % dtype)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of the custom operator *op_name*
        (only ``'FFT2D'`` is known here).

        :raises RuntimeError: for any other operator name
        """
        if op_name == "FFT2D":
            return FFT2DSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    def _run(self, a, fft_length=None):  # pylint: disable=W0221
        """
        Computes the 2-D FFT of *a* over ``self.axes``.

        :param a: input array (float32, float64, complex64 or complex128)
        :param fft_length: optional lengths of the transformed axes,
            forwarded as the second positional argument (``s``) of fft2
        :return: 1-tuple holding the complex result
        :raises TypeError: if the type of *a* is not supported
        """
        # Reject unsupported types before doing the (potentially costly) FFT.
        out_dtype = self._complex_dtype(a.dtype)
        if fft_length is None:
            y = fft2(a, axes=self.axes)
        else:
            y = fft2(a, tuple(fft_length), axes=self.axes)
        return (y.astype(out_dtype), )

    def _infer_shapes(self, a, b=None):  # pylint: disable=W0221,W0237
        """
        Returns the output shape: same shape as *a*, complex type.
        """
        # NOTE(review): the input shape is reported even when fft_length
        # (*b*) is given; its values would be needed to compute the true
        # extent along self.axes.
        out_dtype = self._complex_dtype(a.dtype)
        return (ShapeObject(a.shape, dtype=out_dtype), )

    def _infer_types(self, a, b=None):  # pylint: disable=W0221,W0237
        """
        Returns the output type (complex64 or complex128) for input *a*.
        """
        return (self._complex_dtype(a.dtype), )

    def to_python(self, inputs):
        """
        Returns a 2-tuple ``(import line, code line)`` equivalent to this
        node. *inputs* are formatted into the call as its arguments: one
        input means no fft_length, two inputs means fft_length is the second.
        """
        # NOTE(review): unlike _run, the generated code does not cast the
        # result with astype, so its output type may differ.
        axes = tuple(self.axes) if self.axes is not None else None
        if len(inputs) == 1:
            return ('from numpy.fft import fft2',
                    "return fft2({}, axes={})".format(inputs[0], axes))
        return ('from numpy.fft import fft2',
                "return fft2({}, tuple({}), axes={})".format(
                    inputs[0], inputs[1], axes))

74 

75 

class FFT2DSchema(OperatorSchema):
    """
    Schema of the custom operator *FFT2D* added by this package
    (see @see cl FFT2D); its attributes are the ones declared
    in ``FFT2D.atts``.
    """

    def __init__(self):
        super().__init__('FFT2D')
        self.attributes = FFT2D.atts