Coverage for mlprodict/onnxrt/backend.py: 80%
1"""
2@file
3@brief ONNX Backend for @see cl OnnxInference.
5::
7 import unittest
8 from onnx.backend.test import BackendTest
9 backend_test = BackendTest(backend, __name__)
10 back_test.include('.*add.*')
11 globals().update(backend_test.enable_report().test_cases)
12 unittest.main()
13"""
from io import BytesIO
import unittest
import numpy
from onnx import version, load as onnx_load
from onnx.checker import check_model
from onnx.backend.base import Backend, BackendRep
from .onnx_inference import OnnxInference
from .onnx_micro_runtime import OnnxMicroRuntime
from .onnx_shape_inference import OnnxShapeInference


class _CombineModels:

    def __init__(self, onnx_inference, shape_inference):
        self.onnx_inference = onnx_inference
        self.shape_inference = shape_inference

    @property
    def input_names(self):
        "Returns the input names."
        return self.onnx_inference.input_names

    @property
    def output_names(self):
        "Returns the output names."
        return self.onnx_inference.output_names

    def run(self, inputs, **kwargs):
        "Runs shape inference and onnx inference."
        shapes = self.shape_inference.run(**kwargs)
        results = self.onnx_inference.run(inputs, **kwargs)
        for k, v in results.items():
            if not shapes[k].is_compatible(v):
                raise RuntimeError(
                    "Incompatible shapes %r and %r for output %r." % (
                        shapes[k], v.shape, k))
        return results
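

# The examples below are illustrative sketches added for documentation and are
# not part of the original module. They assume a tiny Add model is enough to
# exercise the backend classes; the helper builds it with onnx.helper.
def _make_example_add_model():
    "Builds a minimal ONNX model computing Z = X + Y (illustration only)."
    from onnx import TensorProto
    from onnx.helper import (
        make_node, make_graph, make_model, make_tensor_value_info, make_opsetid)
    X = make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
    Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, 2])
    Z = make_tensor_value_info('Z', TensorProto.FLOAT, [None, 2])
    graph = make_graph([make_node('Add', ['X', 'Y'], ['Z'])], 'add', [X, Y], [Z])
    return make_model(graph, opset_imports=[make_opsetid('', 15)])


def _example_combine_models():
    """
    Hedged sketch: runs @see cl OnnxInference and @see cl OnnxShapeInference
    together through _CombineModels, which raises if a computed output does
    not match its inferred shape.
    """
    model = _make_example_add_model()
    combined = _CombineModels(OnnxInference(model), OnnxShapeInference(model))
    x = numpy.array([[1., 2.]], dtype=numpy.float32)
    y = numpy.array([[3., 4.]], dtype=numpy.float32)
    return combined.run({'X': x, 'Y': y})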


class OnnxInferenceBackendRep(BackendRep):
    """
    Computes the prediction for an ONNX graph
    loaded with @see cl OnnxInference.

    :param session: @see cl OnnxInference
    """

    def __init__(self, session):
        self._session = session

    def run(self, inputs, **kwargs):  # type: (Any, **Any) -> Tuple[Any, ...]
        """
        Computes the prediction. See @see meth OnnxInference.run.
        """
        if isinstance(inputs, list):
            feeds = {}
            for i, inp in enumerate(self._session.input_names):
                feeds[inp] = inputs[i]
        elif isinstance(inputs, dict):
            feeds = inputs
        elif isinstance(inputs, numpy.ndarray):
            names = self._session.input_names
            if len(names) != 1:
                raise RuntimeError(  # pragma: no cover
                    "Expecting one input not %d." % len(names))
            feeds = {names[0]: inputs}
        else:
            raise TypeError(  # pragma: no cover
                "Unexpected input type %r." % type(inputs))
        outs = self._session.run(feeds)
        output_names = self._session.output_names
        if output_names is None and hasattr(self._session, 'expected_outputs'):
            output_names = [n[0] for n in self._session.expected_outputs]
        if output_names is None:
            raise RuntimeError(  # pragma: no cover
                "output_names cannot be None for type %r." % type(self._session))
        return [outs[name] for name in output_names]
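

def _example_backend_rep():
    """
    Hedged sketch (not part of the original module): wraps OnnxInference in
    OnnxInferenceBackendRep and feeds the inputs as a positional list, the
    way the ONNX backend test suite does; a dictionary or a single array
    is accepted by ``run`` as well.
    """
    model = _make_example_add_model()
    rep = OnnxInferenceBackendRep(OnnxInference(model))
    x = numpy.array([[1., 2.]], dtype=numpy.float32)
    y = numpy.array([[3., 4.]], dtype=numpy.float32)
    # outputs come back as a list ordered like output_names
    return rep.run([x, y])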


class OnnxInferenceBackend(Backend):
    """
    ONNX backend following the pattern from
    `onnx/backend/base.py
    <https://github.com/onnx/onnx/blob/main/onnx/backend/base.py>`_.
    This backend can be run through the following code:

    ::

        import unittest
        from contextlib import redirect_stdout, redirect_stderr
        from io import StringIO
        from onnx.backend.test import BackendTest
        import mlprodict.onnxrt.backend_py as backend

        back_test = BackendTest(backend, __name__)
        back_test.exclude('.*_blvc_.*')
        back_test.exclude('.*_densenet_.*')
        back_test.exclude('.*_densenet121_.*')
        back_test.exclude('.*_inception_.*')
        back_test.exclude('.*_resnet50_.*')
        back_test.exclude('.*_shufflenet_.*')
        back_test.exclude('.*_squeezenet_.*')
        back_test.exclude('.*_vgg19_.*')
        back_test.exclude('.*_zfnet512_.*')
        globals().update(back_test.enable_report().test_cases)
        buffer = StringIO()
        print('---------------------------------')

        if True:
            with redirect_stdout(buffer):
                with redirect_stderr(buffer):
                    res = unittest.main(verbosity=2, exit=False)
        else:
            res = unittest.main(verbosity=2, exit=False)

        testsRun = res.result.testsRun
        errors = len(res.result.errors)
        skipped = len(res.result.skipped)
        unexpectedSuccesses = len(res.result.unexpectedSuccesses)
        expectedFailures = len(res.result.expectedFailures)
        print('---------------------------------')
        print("testsRun=%d errors=%d skipped=%d unexpectedSuccesses=%d "
              "expectedFailures=%d" % (
                  testsRun, errors, skipped, unexpectedSuccesses,
                  expectedFailures))
        ran = testsRun - skipped
        print("ratio=%f" % (1 - errors * 1.0 / ran))
        print('---------------------------------')
        print(buffer.getvalue())
    """

    @classmethod
    def is_compatible(cls, model, device=None, **kwargs):
        """
        Returns whether the model is compatible with the backend.

        :param model: unused
        :param device: None to use the default device or a string (ex: `'CPU'`)
        :return: boolean
        """
        return device is None or device == 'CPU'

    @classmethod
    def is_opset_supported(cls, model):
        """
        Returns whether the opset for the model is supported by the backend.

        :param model: model whose opsets need to be verified
        :return: boolean and an error message if the opset is not supported
        """
        return True, ''

    @classmethod
    def supports_device(cls, device):
        """
        Checks whether the backend is compiled with particular
        device support.
        """
        return device == 'CPU'

    @classmethod
    def create_inference_session(cls, model):
        """
        Instantiates an instance of class @see cl OnnxInference.
        This method should be overridden to change the runtime
        or any other runtime options.
        """
        return OnnxInference(model)

    @classmethod
    def prepare(cls, model, device=None, **kwargs):
        """
        Loads the model and creates @see cl OnnxInference.

        :param model: ModelProto (returned by `onnx.load`),
            string for a filename or bytes for a serialized model
        :param device: requested device for the computation,
            None means the default one which depends on
            the compilation settings
        :param kwargs: see @see cl OnnxInference
        :return: see @see cl OnnxInference
        """
        if isinstance(model, OnnxInferenceBackendRep):
            return model
        if isinstance(model, (OnnxInference, OnnxMicroRuntime,
                              OnnxShapeInference, _CombineModels)):
            return OnnxInferenceBackendRep(model)
        if isinstance(model, (str, bytes)):
            inf = cls.create_inference_session(model)
            return cls.prepare(inf, device, **kwargs)
        else:
            from ..npy.xop_convert import OnnxSubOnnx
            if isinstance(model, OnnxSubOnnx):
                return OnnxInferenceBackendRep(model)

            onnx_version = tuple(map(int, (version.version.split(".")[:3])))
            onnx_supports_serialized_model_check = onnx_version >= (1, 10, 0)
            bin_or_model = (
                model.SerializeToString() if onnx_supports_serialized_model_check
                else model)
            check_model(bin_or_model)
            opset_supported, error_message = cls.is_opset_supported(model)
            if not opset_supported:
                raise unittest.SkipTest(error_message)  # pragma: no cover
            binm = bin_or_model
            if not isinstance(binm, (str, bytes)):
                binm = binm.SerializeToString()
            return cls.prepare(binm, device, **kwargs)

    @classmethod
    def run_model(cls, model, inputs, device=None, **kwargs):
        """
        Computes the prediction.

        :param model: see @see cl OnnxInference returned by function *prepare*
        :param inputs: inputs
        :param device: requested device for the computation,
            None means the default one which depends on
            the compilation settings
        :param kwargs: see @see cl OnnxInference
        :return: predictions
        """
        rep = cls.prepare(model, device, **kwargs)
        return rep.run(inputs, **kwargs)

    @classmethod
    def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
        '''
        This method is not implemented as it is much more efficient
        to run a whole model than every node independently.
        '''
        raise NotImplementedError("Unable to run the model node by node.")
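

def _example_backend_prepare():
    """
    Hedged sketch (not part of the original module): ``prepare`` accepts a
    ModelProto, a filename or serialized bytes and returns an
    @see cl OnnxInferenceBackendRep; ``run_model`` chains both steps.
    """
    model = _make_example_add_model()
    rep = OnnxInferenceBackend.prepare(model.SerializeToString(), 'CPU')
    x = numpy.array([[1., 2.]], dtype=numpy.float32)
    y = numpy.array([[3., 4.]], dtype=numpy.float32)
    # equivalent one-liner: OnnxInferenceBackend.run_model(model, [x, y])
    return rep.run({'X': x, 'Y': y})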


class OnnxInferenceBackendPyC(OnnxInferenceBackend):
    """
    Same backend as @see cl OnnxInferenceBackend but runtime
    is `python_compiled`.
    """

    @classmethod
    def create_inference_session(cls, model):
        return OnnxInference(model, runtime='python_compiled')


class OnnxInferenceBackendOrt(OnnxInferenceBackend):
    """
    Same backend as @see cl OnnxInferenceBackend but runtime
    is `onnxruntime1`.
    """

    @classmethod
    def create_inference_session(cls, model):
        return OnnxInference(model, runtime='onnxruntime1')
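

# Hedged note (not part of the original module): the subclasses in this file
# only override ``create_inference_session`` to select another runtime; the
# ``prepare``/``run_model`` logic is inherited from OnnxInferenceBackend.
# ``onnxruntime1`` assumes onnxruntime is installed, so the sketch below only
# instantiates the pure python runtimes.
def _example_runtime_variants():
    model = _make_example_add_model().SerializeToString()
    return {
        'python': OnnxInferenceBackend.create_inference_session(model),
        'python_compiled': OnnxInferenceBackendPyC.create_inference_session(model),
    }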


class OnnxInferenceBackendMicro(OnnxInferenceBackend):
    """
    Same backend as @see cl OnnxInferenceBackend but runtime
    is @see cl OnnxMicroRuntime.
    """

    @classmethod
    def create_inference_session(cls, model):
        if isinstance(model, str):
            with open(model, 'rb') as f:
                content = onnx_load(f)
        elif isinstance(model, bytes):
            content = onnx_load(BytesIO(model))
        else:
            content = model
        return OnnxMicroRuntime(content)
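

def _example_micro_runtime():
    """
    Hedged sketch (not part of the original module): the micro runtime needs
    a ModelProto, so filenames and bytes are deserialized first. It assumes
    the model only uses the few operators @see cl OnnxMicroRuntime implements
    (Add is assumed to be one of them).
    """
    rt = OnnxInferenceBackendMicro.create_inference_session(
        _make_example_add_model().SerializeToString())
    x = numpy.array([[1., 2.]], dtype=numpy.float32)
    y = numpy.array([[3., 4.]], dtype=numpy.float32)
    return rt.run({'X': x, 'Y': y})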


class OnnxInferenceBackendShape(OnnxInferenceBackend):
    """
    Same backend as @see cl OnnxInferenceBackend but runtime
    is @see cl OnnxShapeInference.
    """

    @classmethod
    def create_inference_session(cls, model):
        if isinstance(model, str):
            with open(model, 'rb') as f:
                content = onnx_load(f)
        elif isinstance(model, bytes):
            content = onnx_load(BytesIO(model))
        else:
            content = model
        return _CombineModels(OnnxInference(content),
                              OnnxShapeInference(content))

    @classmethod
    def run_model(cls, model, inputs, device=None, **kwargs):
        """
        Computes the prediction.

        :param model: see @see cl OnnxShapeInference returned by
            function *prepare*
        :param inputs: inputs
        :param device: requested device for the computation,
            None means the default one which depends on
            the compilation settings
        :param kwargs: see @see cl OnnxInference
        :return: predictions
        """
        rep = cls.prepare(model, device, **kwargs)
        shapes = rep.shape_inference.run(**kwargs)
        results = rep.onnx_inference.run(inputs, **kwargs)
        for k, v in results.items():
            if not shapes[k].is_compatible(v):
                raise RuntimeError(  # pragma: no cover
                    "Incompatible shapes %r and %r for output %r." % (
                        shapes[k], v.shape, k))
        return results
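

def _example_shape_checked_run():
    """
    Hedged sketch (not part of the original module): the shape-checking
    backend pairs @see cl OnnxInference with @see cl OnnxShapeInference
    through _CombineModels, so every computed output is validated against
    its inferred shape when ``run`` is called.
    """
    rep = OnnxInferenceBackendShape.prepare(
        _make_example_add_model().SerializeToString())
    x = numpy.array([[1., 2.]], dtype=numpy.float32)
    y = numpy.array([[3., 4.]], dtype=numpy.float32)
    return rep.run({'X': x, 'Y': y})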


class OnnxInferenceBackendPyEval(OnnxInferenceBackend):
    """
    Same backend as @see cl OnnxInferenceBackend but the model
    is evaluated through @see cl OnnxSubOnnx.
    """

    @classmethod
    def create_inference_session(cls, model):
        from ..npy.xop_convert import OnnxSubOnnx
        if isinstance(model, str):
            with open(model, 'rb') as f:
                content = onnx_load(f)
        elif isinstance(model, bytes):
            content = onnx_load(BytesIO(model))
        else:
            content = model
        return OnnxSubOnnx(content)

    @classmethod
    def run_model(cls, model, inputs, device=None, **kwargs):
        """
        Computes the prediction.

        :param model: see @see cl OnnxShapeInference returned by
            function *prepare*
        :param inputs: inputs
        :param device: requested device for the computation,
            None means the default one which depends on
            the compilation settings
        :param kwargs: see @see cl OnnxInference
        :return: predictions
        """
        rep = cls.prepare(model, device, **kwargs)
        shapes = rep.shape_inference.run(**kwargs)
        results = rep.onnx_inference.run(inputs, **kwargs)
        for k, v in results.items():
            if not shapes[k].is_compatible(v):
                raise RuntimeError(  # pragma: no cover
                    "Incompatible shapes %r and %r for output %r." % (
                        shapes[k], v.shape, k))
        return results