Coverage for mlprodict/onnxrt/ops_cpu/op_complex_abs.py: 88%
Shortcuts on this page
r m x: toggle line displays
j k: next/prev highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ..shape_object import ShapeObject
9from ._op import OpRun
10from ._new_ops import OperatorSchema
class ComplexAbs(OpRun):
    """
    Runtime implementation of the custom operator *ComplexAbs*:
    computes the element-wise magnitude of a complex tensor.
    A *complex64* input gives a *float32* output and a *complex128*
    input gives a *float64* output; any other input type raises
    a :class:`TypeError`.
    """

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc, **options)

    def _find_custom_operator_schema(self, op_name):
        # The schema for ComplexAbs is defined in this module
        # (see class ComplexAbsSchema below), any other name is an error.
        if op_name == "ComplexAbs":
            return ComplexAbsSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    @staticmethod
    def _complex_to_real_dtype(dtype):
        """
        Returns the real type with the same precision as a complex type.
        Shared by the execution, shape inference and type inference
        so the three always agree.

        :param dtype: numpy dtype or numpy scalar type of the input
        :return: *numpy.float32* for complex64,
            *numpy.float64* for complex128
        :raises TypeError: for any other type
        """
        if dtype == numpy.complex64:
            return numpy.float32
        if dtype == numpy.complex128:
            return numpy.float64
        raise TypeError(  # pragma: no cover
            "Unexpected input type for x: %r." % (dtype, ))

    def _run(self, x):  # pylint: disable=W0221
        # Validate the input type before doing any computation.
        real_dtype = self._complex_to_real_dtype(x.dtype)
        # numpy.absolute already returns the real type of matching
        # precision for complex inputs; copy=False avoids a second
        # full-size allocation when no conversion is needed.
        return (numpy.absolute(x).astype(real_dtype, copy=False), )

    def _infer_shapes(self, x):  # pylint: disable=W0221,W0237
        # Same shape as the input, real dtype of matching precision.
        return (ShapeObject(x.shape, self._complex_to_real_dtype(x.dtype)), )

    def _infer_types(self, x):  # pylint: disable=W0221,W0237
        # Here x is the input type itself, not a tensor.
        return (self._complex_to_real_dtype(x), )

    def to_python(self, inputs):
        # The generated python code relies on numpy.absolute.
        return self._to_python_numpy(inputs, 'absolute')
class ComplexAbsSchema(OperatorSchema):
    """
    Schema of the custom operator *ComplexAbs*, which is not part of
    the standard ONNX operator set, see @see cl ComplexAbs.
    The operator takes no attribute.
    """

    def __init__(self):
        super().__init__('ComplexAbs')
        # No attribute is declared for this operator.
        self.attributes = dict()