Coverage for mlprodict/onnxrt/ops_cpu/op_fft2d.py: 80%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from numpy.fft import fft2
9from ..shape_object import ShapeObject
10from ._op import OpRun
11from ._new_ops import OperatorSchema
class FFT2D(OpRun):
    """
    Runtime for the custom operator *FFT2D*: a two-dimensional discrete
    Fourier transform computed with :func:`numpy.fft.fft2`.

    Attribute *axes* (default ``[-2, -1]``) gives the two axes the
    transform is applied on. float32 and complex64 inputs produce a
    complex64 output; float64 and complex128 inputs produce a complex128
    output. Any other input type raises a TypeError.
    """

    atts = {'axes': [-2, -1]}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=FFT2D.atts,
                       **options)
        if self.axes is not None:
            self.axes = tuple(self.axes)
            if len(self.axes) != 2:
                # The tuple is wrapped in a 1-tuple so that '%' does not
                # unpack it into several format arguments (which would
                # raise TypeError instead of the intended ValueError).
                raise ValueError(  # pragma: no cover
                    "axes must be a tuple of 2 integers not %r." % (
                        self.axes, ))

    def _find_custom_operator_schema(self, op_name):
        if op_name == "FFT2D":
            return FFT2DSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    @staticmethod
    def _complex_dtype(dtype):
        """
        Returns the complex type this operator produces for an input type.

        :param dtype: type of the input tensor
        :return: `numpy.complex64` for float32/complex64 inputs,
            `numpy.complex128` for float64/complex128 inputs
        :raises TypeError: for any other input type
        """
        if dtype in (numpy.float32, numpy.complex64):
            return numpy.complex64
        if dtype in (numpy.float64, numpy.complex128):
            return numpy.complex128
        raise TypeError(  # pragma: no cover
            "Unexpected input type: %r." % dtype)

    def _run(self, a, fft_length=None):  # pylint: disable=W0221
        """
        Computes the 2D FFT of *a* along *self.axes*.

        :param a: input tensor
        :param fft_length: optional transform lengths along *self.axes*,
            forwarded to `fft2` as its *s* argument (input is cropped or
            zero-padded to that size)
        :return: tuple holding the complex result
        :raises TypeError: if the input type is not supported
        """
        # Resolve the output type first so an unsupported input type fails
        # before the (possibly expensive) transform is computed.
        dtype = self._complex_dtype(a.dtype)
        if fft_length is None:
            y = fft2(a, axes=self.axes)
        else:
            y = fft2(a, tuple(fft_length), axes=self.axes)
        return (y.astype(dtype), )

    def _infer_shapes(self, a, b=None):  # pylint: disable=W0221,W0237
        """
        Infers the output shape and type from the input *a*.

        The input shape is propagated unchanged.
        NOTE(review): when *b* (fft_length) differs from the input
        dimensions, the real output shape differs along *self.axes*.
        """
        return (ShapeObject(a.shape, dtype=self._complex_dtype(a.dtype)), )

    def _infer_types(self, a, b=None):  # pylint: disable=W0221,W0237
        """Infers the output type from the type of input *a*."""
        return (self._complex_dtype(a.dtype), )

    def to_python(self, inputs):
        """
        Returns the import line and the python code equivalent to this node.

        NOTE(review): the generated code does not cast the result the way
        `_run` does, so a float32 input yields complex128 there.
        """
        axes = tuple(self.axes) if self.axes is not None else None
        if len(inputs) == 1:
            return ('from numpy.fft import fft2',
                    "return fft2({}, axes={})".format(
                        inputs[0], axes))
        return ('from numpy.fft import fft2',
                "return fft2({}, tuple({}), axes={})".format(
                    inputs[0], inputs[1], axes))
class FFT2DSchema(OperatorSchema):
    """
    Schema describing the custom operator *FFT2D* defined in this package
    (runtime: @see cl FFT2D).
    """

    def __init__(self):
        operator_name = 'FFT2D'
        OperatorSchema.__init__(self, operator_name)
        # Shares (does not copy) the attribute defaults declared on the
        # runtime class.
        self.attributes = FFT2D.atts