Coverage for mlprodict/testing/test_utils/utils_backend_onnxruntime.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Inspired from sklearn-onnx, handles two backends.
4"""
5from pyquickhelper.pycode import is_travis_or_appveyor
6from .utils_backend_common_compare import compare_runtime_session
def _capture_output(fct, kind):
    """
    Calls *fct* and, whenever possible, captures what it writes
    on the standard output and error.

    :param fct: function without arguments to call
    :param kind: capture mode, passed unchanged to
        ``cpyquickhelper.io.capture_output`` (this file uses ``'c'``)
    :return: 3-tuple ``(result, out, err)``; *out* and *err* are None
        when no capture took place
    """
    # No capture on CI machines (travis/appveyor) nor when the optional
    # package cpyquickhelper is missing: the function is simply called.
    capture = None
    if not is_travis_or_appveyor():
        try:
            from cpyquickhelper.io import capture_output as capture
        except ImportError:  # pragma: no cover
            capture = None
    if capture is None:
        return fct(), None, None
    return capture(fct, kind)
class InferenceSession2:
    """
    Overwrites class *InferenceSession* to capture
    the standard output and error.

    The wrapped :epkg:`onnxruntime` session is stored in ``self.sess``.
    What is captured while the session is built goes to
    ``self.outi`` / ``self.erri``, what is captured by the last call
    to :meth:`run` goes to ``self.outr`` / ``self.errr``. These values
    are None when nothing could be captured.
    """

    def __init__(self, *args, **kwargs):
        """
        Overwrites the constructor.

        Arguments are forwarded to ``onnxruntime.InferenceSession``
        except ``runtime_options``, an optional dictionary consumed here.
        Its key ``disable_optimisation`` (default False) makes the
        session use ``SessionOptions`` with
        ``GraphOptimizationLevel.ORT_DISABLE_ALL``.
        ``providers`` defaults to ``['CPUExecutionProvider']``.

        :raises RuntimeError: if ``disable_optimisation`` is requested
            while ``sess_options`` is given as well
        """
        from onnxruntime import (
            InferenceSession, GraphOptimizationLevel, SessionOptions)
        kwargs, disable_optimisation = self._prepare_kwargs(args, kwargs)
        if disable_optimisation:
            sess_options = SessionOptions()
            sess_options.graph_optimization_level = (
                GraphOptimizationLevel.ORT_DISABLE_ALL)
            kwargs['sess_options'] = sess_options
        # Filled by run(); initialised so the attributes always exist.
        self.outr = None
        self.errr = None
        self.sess, self.outi, self.erri = _capture_output(
            lambda: InferenceSession(*args, **kwargs), 'c')

    @staticmethod
    def _prepare_kwargs(args, kwargs):
        """
        Computes the keyword arguments forwarded to
        ``onnxruntime.InferenceSession`` and the *disable_optimisation*
        flag. Neither *kwargs* nor the ``runtime_options`` dictionary
        it may contain is modified, so callers can reuse them.

        :param args: positional arguments given to the constructor
        :param kwargs: keyword arguments given to the constructor
        :return: tuple ``(new_kwargs, disable_optimisation)``
        :raises RuntimeError: if *disable_optimisation* is requested and
            ``sess_options`` is given (by keyword or by position)
        """
        kwargs = dict(kwargs)
        runtime_options = kwargs.pop('runtime_options', None) or {}
        # get, not pop: the caller's dictionary must stay untouched.
        disable_optimisation = runtime_options.get(
            'disable_optimisation', False)
        # onnxruntime.InferenceSession(path_or_bytes, sess_options,
        # providers, ...): positional values must be detected too,
        # otherwise the same argument would be given twice.
        sess_options_given = 'sess_options' in kwargs or len(args) > 1
        if disable_optimisation and sess_options_given:
            raise RuntimeError(  # pragma: no cover
                "Incompatible options, 'disable_optimisation' and "
                "'sess_options' cannot be specified at the same time.")
        if 'providers' not in kwargs and len(args) < 3:
            kwargs['providers'] = ['CPUExecutionProvider']
        return kwargs, disable_optimisation

    def run(self, *args, **kwargs):
        """
        Overwrites method *run*: calls the wrapped session and stores
        the captured outputs in ``self.outr`` / ``self.errr``.
        """
        res, self.outr, self.errr = _capture_output(
            lambda: self.sess.run(*args, **kwargs), 'c')
        return res

    def get_inputs(self, *args, **kwargs):
        "Overwrites method *get_inputs* (delegates to the wrapped session)."
        return self.sess.get_inputs(*args, **kwargs)

    def get_outputs(self, *args, **kwargs):
        "Overwrites method *get_outputs* (delegates to the wrapped session)."
        return self.sess.get_outputs(*args, **kwargs)
def compare_runtime(test, decimal=5, options=None,
                    verbose=False, context=None, comparable_outputs=None,
                    intermediate_steps=False, classes=None,
                    disable_optimisation=False):
    """
    The function compares the expected output (computed with
    the model before being converted to ONNX) and the ONNX output
    produced with module :epkg:`onnxruntime`. Sessions are created
    with :class:`InferenceSession2`, a wrapper around
    ``onnxruntime.InferenceSession`` which captures the standard
    output and error. The comparison itself is delegated to
    :func:`compare_runtime_session`.

    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param context: specifies custom operators
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable the optimisations
        :epkg:`onnxruntime` could do
    :return: whatever :func:`compare_runtime_session` returns,
        documented as a tuple (output, lambda function to run
        the predictions)

    An exception is raised if the comparison fails.
    """
    return compare_runtime_session(
        InferenceSession2, test, decimal=decimal, options=options,
        verbose=verbose, context=context,
        comparable_outputs=comparable_outputs,
        intermediate_steps=intermediate_steps,
        classes=classes, disable_optimisation=disable_optimisation)