Coverage for mlprodict/onnxrt/ops_whole/session.py: 96%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Shortcut to *ops_whole*.
5"""
6import json
7from io import BytesIO
8import numpy
9import onnx
class OnnxWholeSession:
    """
    Runs the prediction for a single :epkg:`ONNX`,
    it lets the runtime handle the graph logic as well.
    The graph execution is entirely delegated to
    :epkg:`onnxruntime` (see :meth:`run`).

    :param onnx_data: :epkg:`ONNX` model or data
    :param runtime: runtime to be used, mostly :epkg:`onnxruntime`
    :param runtime_options: runtime options
    :param device: device, a string `cpu`, `cuda`, `cuda:0`...

    .. versionchanged:: 0.8
        Parameter *device* was added.
    """
26 def __init__(self, onnx_data, runtime, runtime_options=None, device=None):
27 if runtime not in ('onnxruntime1', 'onnxruntime1-cuda'):
28 raise NotImplementedError( # pragma: no cover
29 "runtime '{}' is not implemented.".format(runtime))
31 from onnxruntime import ( # delayed
32 InferenceSession, SessionOptions, RunOptions,
33 GraphOptimizationLevel)
34 from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611
35 Fail as OrtFail, InvalidGraph as OrtInvalidGraph,
36 InvalidArgument as OrtInvalidArgument,
37 NotImplemented as OrtNotImplemented,
38 RuntimeException as OrtRuntimeException)
40 if hasattr(onnx_data, 'SerializeToString'):
41 onnx_data = onnx_data.SerializeToString()
42 if isinstance(runtime_options, SessionOptions):
43 sess_options = runtime_options
44 session_options = None
45 runtime_options = None
46 else:
47 session_options = (
48 None if runtime_options is None
49 else runtime_options.get('session_options', None))
50 self.runtime = runtime
51 sess_options = session_options or SessionOptions()
52 self.run_options = RunOptions()
53 self.run_options.log_severity_level = 3
54 self.run_options.log_verbosity_level = 1
56 if session_options is None:
57 if runtime_options is not None:
58 if runtime_options.get('disable_optimisation', False):
59 sess_options.graph_optimization_level = ( # pragma: no cover
60 GraphOptimizationLevel.ORT_ENABLE_ALL)
61 if runtime_options.get('enable_profiling', True):
62 sess_options.enable_profiling = True
63 if runtime_options.get('log_severity_level', 2) != 2:
64 v = runtime_options.get('log_severity_level', 2)
65 sess_options.log_severity_level = v
66 self.run_options.log_severity_level = v
67 elif runtime_options is not None and 'enable_profiling' in runtime_options:
68 raise RuntimeError( # pragma: no cover
69 "session_options and enable_profiling cannot be defined at the "
70 "same time.")
71 elif runtime_options is not None and 'disable_optimisation' in runtime_options:
72 raise RuntimeError( # pragma: no cover
73 "session_options and disable_optimisation cannot be defined at the "
74 "same time.")
75 elif runtime_options is not None and 'log_severity_level' in runtime_options:
76 raise RuntimeError( # pragma: no cover
77 "session_options and log_severity_level cannot be defined at the "
78 "same time.")
79 providers = ['CPUExecutionProvider']
80 if runtime == 'onnxruntime1-cuda':
81 providers = ['CUDAExecutionProvider'] + providers
82 try:
83 self.sess = InferenceSession(onnx_data, sess_options=sess_options,
84 device=device, providers=providers)
85 except (OrtFail, OrtNotImplemented, OrtInvalidGraph,
86 OrtInvalidArgument, OrtRuntimeException, RuntimeError) as e:
87 from ...tools.asv_options_helper import display_onnx
88 raise RuntimeError(
89 "Unable to create InferenceSession due to '{}'\n{}.".format(
90 e, display_onnx(onnx.load(BytesIO(onnx_data))))) from e
91 self.output_names = [_.name for _ in self.sess.get_outputs()]
93 def run(self, inputs):
94 """
95 Computes the predictions.
97 @param inputs dictionary *{variable, value}*
98 @return list of outputs
99 """
100 v = next(iter(inputs.values()))
101 if isinstance(v, (numpy.ndarray, dict)):
102 try:
103 return self.sess._sess.run(
104 self.output_names, inputs, self.run_options)
105 except ValueError as e:
106 raise ValueError(
107 "Issue running inference inputs=%r, expected inputs=%r."
108 "" % (
109 list(sorted(inputs)),
110 [i.name for i in self.sess.get_inputs()])) from e
111 try:
112 return self.sess._sess.run_with_ort_values(
113 inputs, self.output_names, self.run_options)
114 except RuntimeError:
115 return self.sess._sess.run_with_ort_values(
116 {k: v._get_c_value() for k, v in inputs.items()},
117 self.output_names, self.run_options)
119 @staticmethod
120 def process_profiling(js):
121 """
122 Flattens json returned by onnxruntime profiling.
124 :param js: json
125 :return: list of dictionaries
126 """
127 rows = []
128 for row in js:
129 if 'args' in row and isinstance(row['args'], dict):
130 for k, v in row['args'].items():
131 row['args_%s' % k] = v
132 del row['args']
133 rows.append(row)
134 return rows
136 def get_profiling(self):
137 """
138 Returns the profiling informations.
139 """
140 prof = self.sess.end_profiling()
141 with open(prof, 'r') as f:
142 content = f.read()
143 js = json.loads(content)
144 return OnnxWholeSession.process_profiling(js)