Coverage for mlprodict/onnxrt/ops_cpu/op_rfft.py: 79%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from numpy.fft import rfft
9from ..shape_object import ShapeObject
10from ._op import OpRun
11from ._new_ops import OperatorSchema
class RFFT(OpRun):
    """
    Runtime implementation of the custom operator *RFFT*:
    a discrete Fourier transform of a real input computed with
    :func:`numpy.fft.rfft` along attribute *axis* (default -1).
    Only float32 inputs (result complex64) and float64 inputs
    (result complex128) are supported.
    """

    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=RFFT.atts,
                       **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of the custom operator *op_name*.

        :param op_name: operator type name
        :return: an instance of :class:`RFFTSchema`
        :raises RuntimeError: if *op_name* is not ``'RFFT'``
        """
        if op_name == "RFFT":
            return RFFTSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    @staticmethod
    def _complex_dtype(dtype):
        """
        Maps a supported real input type to the complex output type.
        Shared by `_run`, `_infer_shapes` and `_infer_types` so that
        all three accept and reject exactly the same types.

        :param dtype: input element type
        :return: ``numpy.complex64`` or ``numpy.complex128``
        :raises TypeError: for any type other than float32 / float64
        """
        if dtype == numpy.float32:
            return numpy.complex64
        if dtype == numpy.float64:
            return numpy.complex128
        raise TypeError(  # pragma: no cover
            "Unexpected input type: %r." % dtype)

    def _run(self, a, fft_length=None):  # pylint: disable=W0221
        """
        Computes the real FFT of *a* along ``self.axis``.

        :param a: real input array (float32 or float64)
        :param fft_length: optional array whose first element is the
            number of points passed to :func:`numpy.fft.rfft` as *n*
        :return: 1-tuple holding the complex result
        :raises TypeError: if *a* is neither float32 nor float64
        """
        # Validate the type first so an unsupported input fails
        # immediately instead of after a wasted transform.
        out_dtype = self._complex_dtype(a.dtype)
        if fft_length is not None:
            fft_length = fft_length[0]
        y = rfft(a, fft_length, axis=self.axis)
        # Cast explicitly so the output type does not depend on which
        # complex type numpy.fft.rfft happens to return.
        return (y.astype(out_dtype), )

    def _infer_shapes(self, a, b=None):  # pylint: disable=W0221,W0237
        """
        Infers the output shape and type from input *a*.

        NOTE(review): the input shape is reused unchanged, whereas
        rfft produces ``n // 2 + 1`` values along *axis*. Confirm
        whether callers rely on this before changing it.

        :raises TypeError: if *a* is neither float32 nor float64
        """
        return (ShapeObject(a.shape, dtype=self._complex_dtype(a.dtype)), )

    def _infer_types(self, a, b=None):  # pylint: disable=W0221,W0237
        """
        Infers the output type from the type of input *a*.

        :raises TypeError: if *a* is neither float32 nor float64
        """
        return (self._complex_dtype(a.dtype), )

    def to_python(self, inputs):
        """
        Returns the import line and the python code equivalent to
        this operator, given the names of its inputs.

        NOTE(review): unlike `_run`, the generated code does not cast
        the result to complex64 for float32 inputs.
        """
        if len(inputs) == 1:
            return ('from numpy.fft import rfft',
                    "return rfft({}, axis={})".format(
                        inputs[0], self.axis))
        return ('from numpy.fft import rfft',
                "return rfft({}, {}[0], axis={})".format(
                    inputs[0], inputs[1], self.axis))
class RFFTSchema(OperatorSchema):
    """
    Defines a schema for operators added in this package
    such as @see cl RFFT.
    """

    def __init__(self):
        OperatorSchema.__init__(self, 'RFFT')
        # Expose the same attribute defaults as the runtime class.
        self.attributes = RFFT.atts