Coverage for mlprodict/onnxrt/onnx_micro_runtime.py: 91%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""
@file
@brief Micro runtime for ONNX.

.. versionadded:: 0.6
"""
7import numpy
8from ..onnx_tools.onnx2py_helper import _var_as_dict
class OnnxMicroRuntime:
    """
    Implements a micro runtime for ONNX graphs.
    It does not implement all the operator types.

    :param model_onnx: ONNX model

    .. runpython::
        :showcode:

        import pprint
        import numpy
        from mlprodict.onnxrt.onnx_micro_runtime import OnnxMicroRuntime
        from mlprodict.npy.xop import loadop

        OnnxAdd = loadop('Add')

        dtype = numpy.float32
        opset = 15
        x = numpy.array([1, 2, 4, 5, 5, 4]).astype(
            numpy.float32).reshape((3, 2))
        cop = OnnxAdd('X', numpy.array([1], dtype=dtype), op_version=opset)
        cop4 = OnnxAdd(cop, numpy.array([2], dtype=dtype), op_version=opset,
                       output_names=['Y'])
        model_def = cop4.to_onnx({'X': x}, target_opset=opset)
        rt = OnnxMicroRuntime(model_def)
        out = rt.run({'X': x})
        pprint.pprint(out)
    """

    def __init__(self, model_onnx):
        # Only the presence of a ``graph`` attribute is checked.
        if not hasattr(model_onnx, 'graph'):
            raise TypeError(
                "model_onnx is not an ONNX graph but %r." % type(model_onnx))
        self.model_onnx = model_onnx

    @property
    def input_names(self):
        "Returns input names."
        return [i.name for i in self.model_onnx.graph.input]

    @property
    def output_names(self):
        "Returns output names."
        return [i.name for i in self.model_onnx.graph.output]

    def run(self, inputs):
        """
        Computes the outputs of the graph.

        Nodes are evaluated in the order they are stored in the graph.
        Every operator is dispatched to the method ``_op_<op_type.lower()>``.

        :param inputs: dictionary
        :return: all intermediates results and output as a dictionary
        :raises TypeError: if *inputs* is not a dictionary
        :raises NotImplementedError: if an operator has no runtime method
        """
        if not isinstance(inputs, dict):
            raise TypeError(
                "inputs must be a dictionary not %r." % type(inputs))
        # Shallow copy: the caller's dictionary is never modified.
        results = inputs.copy()

        for init in self.model_onnx.graph.initializer:
            results[init.name] = _var_as_dict(init)['value']

        for node in self.model_onnx.graph.node:
            op_type = node.op_type
            # An empty name stands for an omitted optional input in ONNX,
            # it is forwarded as None instead of failing on the lookup.
            inp = [results[n] if n else None for n in node.input]
            meth_name = "_op_%s" % op_type.lower()
            if not hasattr(self, meth_name):
                raise NotImplementedError(
                    "OnnxMicroRuntime does not implement operator %r." % op_type)
            kwargs = {at.name: _var_as_dict(at)['value']
                      for at in node.attribute}
            out = getattr(self, meth_name)(*inp, **kwargs)
            # Every runtime method returns a tuple, one item per output.
            for n, o in zip(node.output, out):
                results[n] = o

        return results

    ########################
    # Runtime for operators
    ########################

    @staticmethod
    def _normalize_axes(axes):
        """
        Converts *axes* into a value accepted by the parameter ``axis``
        of numpy reduction functions.

        :param axes: None, an integer, a numpy integer, a 0-d array
            or a sequence of integers
        :return: None, an integer or a tuple, an empty sequence
            becomes None (all axes)
        """
        if axes is None or isinstance(axes, int):
            return axes
        if isinstance(axes, numpy.integer):
            return int(axes)
        if isinstance(axes, numpy.ndarray) and len(axes.shape) == 0:
            return int(axes)
        return tuple(axes) if len(axes) > 0 else None

    def _op_abs(self, x):
        "Runtime for operator :epkg:`Op:Abs`."
        return (numpy.abs(x), )

    def _op_add(self, x, y):
        "Runtime for operator :epkg:`Op:Add`."
        return (x + y, )

    def _op_concat(self, *args, axis=None):
        "Runtime for operator :epkg:`Op:Concat`."
        def _preprocess(a, axis):
            # Appends trailing dimensions of size 1 so that *axis* exists.
            if axis >= len(a.shape):
                new_shape = a.shape + (1, ) * (axis + 1 - len(a.shape))
                return a.reshape(new_shape)
            return a

        targs = tuple(_preprocess(a, axis) for a in args)
        return (numpy.concatenate(targs, axis), )

    def _op_gemm(self, a, b, c=None, alpha=None, beta=None,
                 transA=False, transB=False):
        """
        Runtime for operator :epkg:`Op:Gemm`.
        Computes ``alpha * a' @ b' + beta * c`` where ``a'`` and ``b'``
        are transposed when *transA*, *transB* are set.
        *alpha* and *beta* are 1 when left to None,
        *c* is ignored when it is None or when *beta* is 0.
        """
        valid = (int, bool, numpy.integer, numpy.bool_)
        if not isinstance(transA, valid):
            raise TypeError(  # pragma: no cover
                "Unexpected type for transA: %r." % type(transA))
        if not isinstance(transB, valid):
            raise TypeError(  # pragma: no cover
                "Unexpected type for transB: %r." % type(transB))
        if alpha is None:
            alpha = 1.
        if beta is None:
            beta = 1.
        left = a.T if transA else a
        right = b.T if transB else b
        o = numpy.dot(left, right) * alpha
        if c is not None and beta != 0:
            o += c * beta
        return (o, )

    def _op_gather(self, x, indices, axis=None):
        "Runtime for operator :epkg:`Op:Gather`."
        if not x.flags['C_CONTIGUOUS']:
            x = numpy.ascontiguousarray(x)
        if not indices.flags['C_CONTIGUOUS']:
            indices = numpy.ascontiguousarray(indices)
        return (numpy.take(x, indices, axis=axis), )

    def _op_identity(self, x):
        "Runtime for operator :epkg:`Op:Identity`."
        return (x, )

    def _op_matmul(self, x, y):
        "Runtime for operator :epkg:`Op:MatMul`."
        return (numpy.matmul(x, y), )

    def _op_max(self, *inps):
        """
        Runtime for operator :epkg:`Op:Max`.
        Accepts any number of inputs (at least one).
        """
        if len(inps) == 0:
            raise TypeError("Operator Max needs at least one input.")
        # numpy.maximum is binary, its third positional argument is
        # the output buffer: inputs are reduced two by two.
        res = inps[0]
        for other in inps[1:]:
            res = numpy.maximum(res, other)
        return (res, )

    def _op_mul(self, x, y):
        "Runtime for operator :epkg:`Op:Mul`."
        return (x * y, )

    def _op_reduceprod(self, data, axes=None, keepdims=None):
        "Runtime for operator :epkg:`Op:ReduceProd`."
        axes = self._normalize_axes(axes)
        return (numpy.prod(data, axis=axes,
                           keepdims=keepdims,
                           dtype=data.dtype), )

    def _op_reducesum(self, data, axes=None, keepdims=None,
                      noop_with_empty_axes=None):
        """
        Runtime for operator :epkg:`Op:ReduceSum`.
        When *axes* is None or empty, the input is returned unchanged
        if *noop_with_empty_axes* is set, otherwise all axes are reduced.
        """
        axes = self._normalize_axes(axes)
        if axes is None and noop_with_empty_axes:
            return (data, )
        return (numpy.sum(data, axis=axes,
                          keepdims=keepdims,
                          dtype=data.dtype), )

    def _op_reshape(self, x, shape):
        "Runtime for operator :epkg:`Op:Reshape`."
        return (x.reshape(shape), )

    def _op_shape(self, x):
        "Runtime for operator :epkg:`Op:Shape`."
        return (numpy.array(list(x.shape), dtype=numpy.int64), )

    def _op_squeeze(self, x, axes=None):
        """
        Runtime for operator :epkg:`Op:Squeeze`.
        The input is returned unchanged when *axes* is None.
        """
        if axes is None:
            return (x, )
        if hasattr(axes, '__iter__'):
            return (numpy.squeeze(x, axis=tuple(axes)), )
        return (numpy.squeeze(x, axis=axes), )

    def _op_transpose(self, x, perm=None):
        "Runtime for operator :epkg:`Op:Transpose`."
        return (numpy.transpose(x, perm), )

    def _op_unsqueeze(self, x, axes=None):
        """
        Runtime for operator :epkg:`Op:Unsqueeze`.
        The input is returned unchanged when *axes* is None.
        """
        if axes is None:
            return (x, )
        if hasattr(axes, '__iter__'):
            return (numpy.expand_dims(x, axis=tuple(axes)), )
        return (numpy.expand_dims(x, axis=axes), )