Coverage for mlprodict/tools/ort_wrapper.py: 88%
1"""
2@file
3@brief Wrapper around :epkg:`onnxruntime`.
5.. versionadded:: 0.6
6"""
7import os
8from onnx import numpy_helper

class InferenceSession:  # pylint: disable=E0102
    """
    Wrapper around InferenceSession from :epkg:`onnxruntime`.

    :param onnx_bytes: onnx bytes
    :param sess_options: session options
    :param log_severity_level: change the logging level
    :param runtime: runtime to use, `onnxruntime`, `onnxruntime-cuda`, ...
    :param providers: providers
    """

    def __init__(self, onnx_bytes, sess_options=None, log_severity_level=4,
                 runtime='onnxruntime', providers=None):
        from onnxruntime import (  # pylint: disable=W0611
            SessionOptions, RunOptions,
            InferenceSession as OrtInferenceSession,
            set_default_logger_severity)
        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            OrtValue as C_OrtValue)

        self.C_OrtValue = C_OrtValue

        self.log_severity_level = log_severity_level
        if providers is not None:
            self.providers = providers
        elif runtime in (None, 'onnxruntime', 'onnxruntime1', 'onnxruntime2'):
            providers = ['CPUExecutionProvider']
        elif runtime in ('onnxruntime-cuda', 'onnxruntime1-cuda',
                         'onnxruntime2-cuda'):
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            raise ValueError(
                "Unexpected value %r for onnxruntime." % (runtime, ))
        self.providers = providers
        set_default_logger_severity(3)
        if sess_options is None:
            self.so = SessionOptions()
            self.so.log_severity_level = log_severity_level
            self.sess = OrtInferenceSession(
                onnx_bytes, sess_options=self.so,
                providers=self.providers)
        else:
            self.so = sess_options
            self.sess = OrtInferenceSession(
                onnx_bytes, sess_options=sess_options,
                providers=self.providers)
        self.ro = RunOptions()
        self.ro.log_severity_level = log_severity_level
        self.ro.log_verbosity_level = log_severity_level
        self.output_names = [o.name for o in self.get_outputs()]

    def run(self, output_names, input_feed, run_options=None):
        """
        Executes the ONNX graph.

        :param output_names: None for all, a name for a specific output
        :param input_feed: dictionary of inputs
        :param run_options: None or RunOptions
        :return: list of results
        """
        # If any input is already an OrtValue, call the low-level C API
        # directly and skip the conversion done by the regular run method.
        if any(map(lambda v: isinstance(v, self.C_OrtValue),
                   input_feed.values())):
            return self.sess._sess.run_with_ort_values(
                input_feed, self.output_names, run_options or self.ro)
        return self.sess.run(output_names, input_feed, run_options or self.ro)

    def get_inputs(self):
        "Returns input types."
        return self.sess.get_inputs()

    def get_outputs(self):
        "Returns output types."
        return self.sess.get_outputs()

    def end_profiling(self):
        "Ends profiling."
        return self.sess.end_profiling()
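
# A minimal usage sketch (not part of the original module): create a session
# from serialized bytes and run it on a random input. The model path
# 'model.onnx', the input name 'X' and the input shape are hypothetical and
# depend on the model being loaded.
#
#     import numpy
#     from mlprodict.tools.ort_wrapper import InferenceSession
#
#     with open("model.onnx", "rb") as f:
#         sess = InferenceSession(f.read())
#     feeds = {'X': numpy.random.rand(1, 4).astype(numpy.float32)}
#     results = sess.run(None, feeds)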


def prepare_c_profiling(model_onnx, inputs, dest=None):
    """
    Prepares model and data to be profiled with tool `perftest
    <https://github.com/microsoft/onnxruntime/tree/
    master/onnxruntime/test/perftest>`_ (onnxruntime) or
    `onnx_test_runner <https://github.com/microsoft/
    onnxruntime/blob/master/docs/Model_Test.md>`_.
    It saves the model in folder *dest* and dumps the inputs in a subfolder.

    :param model_onnx: onnx model
    :param inputs: inputs as a list or a dictionary
    :param dest: destination folder, None means the current folder
    :return: command line to use
    """
    if dest is None:
        dest = "."
    if not os.path.exists(dest):
        os.makedirs(dest)  # pragma: no cover
    dest = os.path.abspath(dest)
    name = "model.onnx"
    model_bytes = model_onnx.SerializeToString()
    with open(os.path.join(dest, name), "wb") as f:
        f.write(model_bytes)
    sess = InferenceSession(model_bytes, providers=['CPUExecutionProvider'])
    input_names = [_.name for _ in sess.get_inputs()]
    if isinstance(inputs, list):
        dict_inputs = dict(zip(input_names, inputs))
    else:
        dict_inputs = inputs
        inputs = [dict_inputs[n] for n in input_names]
    outputs = sess.run(None, dict_inputs)
    sub = os.path.join(dest, "test_data_set_0")
    if not os.path.exists(sub):
        os.makedirs(sub)
    for i, v in enumerate(inputs):
        n = os.path.join(sub, "input_%d.pb" % i)
        pr = numpy_helper.from_array(v)
        with open(n, "wb") as f:
            f.write(pr.SerializeToString())
    for i, v in enumerate(outputs):
        n = os.path.join(sub, "output_%d.pb" % i)
        pr = numpy_helper.from_array(v)
        with open(n, "wb") as f:
            f.write(pr.SerializeToString())

    cmd = 'onnx_test_runner -e cpu -r 100 -c 1 "%s"' % dest
    return cmd
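
# A minimal usage sketch (not part of the original module): dump a model and
# one set of inputs into a folder, then print the command line suggested for
# onnx_test_runner. The path 'model.onnx', the folder 'temp_profiling' and
# the input shape are hypothetical, they depend on the profiled model.
#
#     import numpy
#     import onnx
#     from mlprodict.tools.ort_wrapper import prepare_c_profiling
#
#     onx = onnx.load("model.onnx")
#     x = numpy.random.rand(10, 4).astype(numpy.float32)
#     cmd = prepare_c_profiling(onx, [x], dest="temp_profiling")
#     print(cmd)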