Coverage for mlprodict/onnxrt/ops_cpu/op_fft.py: 80%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from numpy.fft import fft
9from ..shape_object import ShapeObject
10from ._op import OpRun
11from ._new_ops import OperatorSchema
class FFT(OpRun):
    """
    Runtime for the custom operator *FFT*, a thin wrapper around
    :func:`numpy.fft.fft`.

    The operator declares a single attribute, *axis* (default ``-1``),
    which is forwarded to *numpy* as the axis along which the transform
    is computed. Inputs of type *float32* / *complex64* produce
    *complex64* results, inputs of type *float64* / *complex128*
    produce *complex128* results; any other input type is rejected
    with a :class:`TypeError`.
    """

    # Attributes declared by the operator and their default values.
    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        @param      onnx_node   node handed over to @see cl OpRun
        @param      desc        forwarded unchanged to @see cl OpRun
        @param      options     additional options forwarded to
                                @see cl OpRun
        """
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=FFT.atts, **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of the custom operator *op_name*.
        Only ``"FFT"`` is known here, any other name raises
        a :class:`RuntimeError`.
        """
        if op_name != "FFT":
            raise RuntimeError(  # pragma: no cover
                "Unable to find a schema for operator '{}'.".format(op_name))
        return FFTSchema()

    def _run(self, a, fft_length=None):  # pylint: disable=W0221
        """
        Computes the discrete Fourier transform of *a*.

        @param      a           input array
        @param      fft_length  optional, indexable, its first element is
                                passed to *numpy* as the transform length
        @return                 tuple with one complex array
        """
        # The transform is computed before the input type is validated,
        # the type check only decides the precision of the result.
        if fft_length is None:
            spectrum = fft(a, axis=self.axis)
        else:
            spectrum = fft(a, fft_length[0], axis=self.axis)

        if a.dtype in (numpy.float32, numpy.complex64):
            target = numpy.complex64
        elif a.dtype in (numpy.float64, numpy.complex128):
            target = numpy.complex128
        else:
            raise TypeError(  # pragma: no cover
                "Unexpected input type: %r." % a.dtype)
        return (spectrum.astype(target), )

    def _infer_shapes(self, a, b=None):  # pylint: disable=W0221,W0237
        """
        Returns the output shape: same shape as *a*, complex type
        matching the precision of *a*. Argument *b* is not used.
        """
        if a.dtype in (numpy.float32, numpy.complex64):
            out_dtype = numpy.complex64
        elif a.dtype in (numpy.float64, numpy.complex128):
            out_dtype = numpy.complex128
        else:
            raise TypeError(  # pragma: no cover
                "Unexpected input type: %r." % a.dtype)
        return (ShapeObject(a.shape, dtype=out_dtype), )

    def _infer_types(self, a, b=None):  # pylint: disable=W0221,W0237
        """
        Returns the output type: the complex type matching
        the precision of *a*. Argument *b* is not used.
        """
        single = a.dtype in (numpy.float32, numpy.complex64)
        if not single and a.dtype not in (numpy.float64, numpy.complex128):
            raise TypeError(  # pragma: no cover
                "Unexpected input type: %r." % a.dtype)
        return (numpy.complex64 if single else numpy.complex128, )

    def to_python(self, inputs):
        """
        Returns the python code equivalent to this operator.

        @param      inputs      input names
        @return                 tuple *(import statement, code)*
        """
        header = 'from numpy.fft import fft'
        if len(inputs) != 1:
            # A second input holds the transform length.
            code = "return fft({}, {}[0], axis={})".format(
                inputs[0], inputs[1], self.axis)
        else:
            code = "return fft({}, axis={})".format(
                inputs[0], self.axis)
        return (header, code)
class FFTSchema(OperatorSchema):
    """
    Schema describing the custom operator implemented
    by @see cl FFT, an operator added in this package.
    """

    def __init__(self):
        """
        Registers the schema under the name ``'FFT'`` and exposes
        the attributes declared by @see cl FFT.
        """
        operator_name = 'FFT'
        OperatorSchema.__init__(self, operator_name)
        # Same dictionary object as FFT.atts, it is not copied.
        self.attributes = FFT.atts