Coverage for mlprodict/onnxrt/ops_cpu/op_constant.py: 86%
Shortcuts on this page
r m x: toggle line displays
j k: next/prev highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRun
10from ..shape_object import ShapeObject
13def _check_dtype(val):
14 a = val.dtype
15 if not isinstance(a, numpy.dtype) and a not in {
16 numpy.int8, numpy.uint8, numpy.float16, numpy.float32,
17 numpy.float64, numpy.int32, numpy.int64, numpy.int16,
18 numpy.uint16, numpy.uint32, numpy.bool_, numpy.str_,
19 numpy.uint64, bool, str, }:
20 raise TypeError( # pragma: no cover
21 "Type ({}, {}) is not a numpy type (operator 'Constant')".format(
22 a, type(a)))
class Constant_9(OpRun):
    """
    Runtime for operator *Constant* whose output is the tensor stored
    in attribute ``value``.
    """

    atts = {'value': numpy.array([0], dtype=numpy.float32)}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Constant_9.atts,
                       **options)
        # The constant is fixed at construction time and validated once.
        self.cst = self.value
        _check_dtype(self.cst)

    def _run(self):  # pylint: disable=W0221
        "Returns the stored constant wrapped in a 1-tuple."
        return (self.cst, )

    def _infer_shapes(self):  # pylint: disable=W0221
        "Returns a 1-tuple holding the ShapeObject of the constant."
        cst = self.cst
        return (ShapeObject(cst.shape, cst.dtype), )

    def _infer_types(self):  # pylint: disable=W0221
        "Returns a 1-tuple holding the dtype of the constant."
        return (self.cst.dtype, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends ``{'temp': 0}`` to its outputs
        (presumably: no temporary buffer is needed — TODO confirm).
        """
        outputs = self.run(*args, **kwargs)
        return ({'temp': 0}, ) + outputs
class Constant_11(OpRun):
    """
    Runtime for operator *Constant* accepting either a dense ``value``
    or a ``sparse_value`` attribute; ``sparse_value`` wins when set.
    """

    atts = {'value': numpy.array([0], dtype=numpy.float32),
            'sparse_value': None, }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Constant_11.atts,
                       **options)
        # sparse_value may be missing or None: fall back to the dense value.
        sparse = getattr(self, 'sparse_value', None)
        self.cst = self.value if sparse is None else sparse
        _check_dtype(self.cst)

    def _run(self):  # pylint: disable=W0221
        "Returns the stored constant wrapped in a 1-tuple."
        return (self.cst, )

    def _infer_shapes(self):  # pylint: disable=W0221
        "Returns a 1-tuple holding the ShapeObject of the constant."
        cst = self.cst
        return (ShapeObject(cst.shape, cst.dtype), )

    def _infer_types(self):  # pylint: disable=W0221
        "Returns a 1-tuple holding the dtype of the constant."
        return (self.cst.dtype, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends ``{'temp': 0}`` to its outputs
        (presumably: no temporary buffer is needed — TODO confirm).
        """
        outputs = self.run(*args, **kwargs)
        return ({'temp': 0}, ) + outputs
class Constant_12(OpRun):
    """
    Runtime for operator *Constant* whose output comes from the first
    attribute set among ``sparse_value``, ``value_float``, ``value_floats``,
    ``value_int``, ``value_ints``, ``value_string``, ``value_strings``
    and ``value``.
    """

    atts = {'value': numpy.array([0], dtype=numpy.float32),
            'sparse_value': None,
            'value_float': None,
            'value_floats': None,
            'value_int': None,
            'value_ints': None,
            'value_string': None,
            'value_strings': None,
            }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Constant_12.atts,
                       **options)
        # Attributes are checked in this order; the first one present and
        # not None defines the constant. The second item is the dtype the
        # value is converted to (None keeps the value's own type).
        candidates = (
            ('sparse_value', None),
            ('value_float', numpy.float32),
            ('value_floats', numpy.float32),
            ('value_int', numpy.int64),
            ('value_ints', numpy.int64),
            ('value_string', None),
            ('value_strings', None),
            ('value', None),
        )
        for name, dtype in candidates:
            val = getattr(self, name, None)
            if val is not None:
                break
        else:
            raise AttributeError(  # pragma: no cover
                "No constant is defined for operator 'Constant'.")
        self.cst = self._to_constant(val, dtype)
        _check_dtype(self.cst)

    @staticmethod
    def _to_constant(val, dtype):
        """
        Converts an attribute value into the constant returned by the operator.

        :param val: attribute value (numpy array/scalar, or a plain Python
            scalar, list, str or bytes)
        :param dtype: numpy type to convert to, or None to keep the value's type
        :return: *val* converted with ``astype`` when it supports it,
            otherwise a new numpy array built from it
        """
        if dtype is not None:
            # numpy inputs keep the historical ``astype`` conversion; plain
            # Python scalars/lists (no ``astype``) used to raise here.
            if hasattr(val, 'astype'):
                return val.astype(dtype)
            return numpy.array(val, dtype=dtype)
        if hasattr(val, 'dtype'):
            return val
        # str/bytes/list values have no dtype and would make _check_dtype
        # fail with AttributeError: wrap them into a numpy array.
        return numpy.array(val)

    def _run(self):  # pylint: disable=W0221
        "Returns the stored constant wrapped in a 1-tuple."
        return (self.cst, )

    def _infer_shapes(self):  # pylint: disable=W0221
        "Returns a 1-tuple holding the ShapeObject of the constant."
        return (ShapeObject(self.cst.shape, self.cst.dtype), )

    def _infer_types(self):  # pylint: disable=W0221
        "Returns a 1-tuple holding the dtype of the constant."
        return (self.cst.dtype, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends ``dict(temp=0)`` to its outputs
        (presumably: no temporary buffer is needed — TODO confirm).
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res
# ``Constant`` is bound to the most recent implementation supported by the
# opset of the installed onnx package.
Constant = (
    Constant_12 if onnx_opset_version() >= 12
    else Constant_11 if onnx_opset_version() >= 11
    else Constant_9)