Coverage for mlprodict/onnxrt/ops_cpu/op_complex_abs.py: 88%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

34 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ..shape_object import ShapeObject 

9from ._op import OpRun 

10from ._new_ops import OperatorSchema 

11 

12 

class ComplexAbs(OpRun):
    """
    Runtime for the custom operator *ComplexAbs*: element-wise modulus
    of a complex tensor. A ``complex64`` input produces a ``float32``
    output, a ``complex128`` input a ``float64`` output, any other input
    type raises :class:`TypeError`.
    """

    # Supported complex scalar types mapped to the real type of the result.
    # Runtime, shape inference and type inference all read this single
    # table so they cannot disagree with each other.
    _complex_to_real = {
        numpy.complex64: numpy.float32,
        numpy.complex128: numpy.float64,
    }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc, **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of this custom operator.

        :param op_name: operator name, only ``'ComplexAbs'`` is known
        :return: instance of @see cl ComplexAbsSchema
        :raises RuntimeError: for any other operator name
        """
        if op_name == "ComplexAbs":
            return ComplexAbsSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    @classmethod
    def _real_dtype(cls, dtype):
        """
        Returns the real numpy type matching the complex type *dtype*.

        :param dtype: numpy dtype or scalar type such as
            ``numpy.complex64`` or ``numpy.dtype('complex128')``
        :return: ``numpy.float32`` or ``numpy.float64``
        :raises TypeError: if *dtype* is not a supported complex type
        """
        try:
            # Normalizes scalar types, dtypes and dtype strings alike;
            # ``.type`` also drops the byte order of the dtype.
            key = numpy.dtype(dtype).type
        except (TypeError, ValueError):
            key = None
        real = cls._complex_to_real.get(key)
        if real is None:
            # The argument is wrapped in a tuple so '%' formatting stays
            # correct even if *dtype* happens to be a tuple.
            raise TypeError(  # pragma: no cover
                "Unexpected input type for x: %r." % (dtype, ))
        return real

    def _run(self, x):  # pylint: disable=W0221
        """
        Computes the modulus of every element of *x*.

        :param x: complex numpy array (complex64 or complex128)
        :return: 1-tuple holding the real-valued result
        :raises TypeError: if *x* is not complex64 or complex128
        """
        # Validate the input type before doing any computation.
        real = self._real_dtype(x.dtype)
        y = numpy.absolute(x)
        if y.dtype != real:
            # numpy already returns float32/float64 for complex64/complex128,
            # so this cast is normally skipped and no extra copy is made.
            y = y.astype(real)
        return (y, )

    def _infer_shapes(self, x):  # pylint: disable=W0221,W0237
        """
        Infers the output shape: same shape as *x*, real element type.

        :param x: shape description of the input (reads ``x.shape`` and
            ``x.dtype``)
        :return: 1-tuple with the output ShapeObject
        :raises TypeError: if the input type is not a supported complex type
        """
        return (ShapeObject(x.shape, self._real_dtype(x.dtype)), )

    def _infer_types(self, x):  # pylint: disable=W0221,W0237
        """
        Infers the output element type from the input element type.

        :param x: input element type (numpy type or dtype)
        :return: 1-tuple with ``numpy.float32`` or ``numpy.float64``
        :raises TypeError: if *x* is not a supported complex type
        """
        return (self._real_dtype(x), )

    def to_python(self, inputs):
        """
        Returns the python code equivalent to this operator,
        built from ``numpy.absolute``.
        """
        return self._to_python_numpy(inputs, 'absolute')

55 

56 

class ComplexAbsSchema(OperatorSchema):
    """
    Operator schema registered by this package for the custom
    operator @see cl ComplexAbs. The schema is named ``'ComplexAbs'``
    and declares an empty attribute table.
    """

    def __init__(self):
        super().__init__('ComplexAbs')
        # No attribute is declared for this operator.
        self.attributes = dict()