Coverage for mlprodict/onnxrt/ops_onnxruntime/_op.py: 96%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Shortcut to *ops_onnxruntime*.
5"""
6import numpy
7import onnx.defs
8from onnx.helper import make_tensor
9from onnx.onnx_cpp2py_export.shape_inference import InferenceError # pylint: disable=E0401,E0611
10from ...tools.ort_wrapper import InferenceSession
11from ...onnx_tools.onnx2py_helper import guess_proto_dtype
12from ...onnx_tools.optim.graph_schema_helper import (
13 get_defined_inputs, get_defined_outputs, proto2vars)
# Maps an operator name to its ONNX schema; when the same name is returned
# several times by get_all_schemas_with_history, the last occurrence wins.
_schemas = dict(
    (op_schema.name, op_schema)
    for op_schema in onnx.defs.get_all_schemas_with_history())
class OpRunOnnxRuntime:
    """
    Unique operator which calls :epkg:`onnxruntime`
    to compute predictions for one operator.

    The node is converted into a single-node ONNX graph through an
    ``Onnx<op_type>`` algebra class, then loaded into an
    :epkg:`onnxruntime` inference session.
    """

    def __init__(self, onnx_node, desc=None, variables=None,
                 dtype=None, runtime=None, **options):
        """
        :param onnx_node: :epkg:`onnx` node
        :param desc: internal representation, every attribute stored in
            ``desc['atts']`` (a dictionary with a key ``'value'``) is
            added to *options*
        :param variables: registered variables created by previous operators
        :param dtype: float computation type
        :param runtime: `onnxruntime1`, `onnxruntime1-cuda`, ...
        :param options: runtime options
        :raises ValueError: if an attribute in *desc* has no ``'value'``
        """
        self._provider = 'onnxruntime'
        self.onnx_node = onnx_node
        self.desc = desc
        self.runtime = runtime
        self._schema = _schemas.get(onnx_node.op_type, None)
        if desc is not None and 'atts' in desc:
            for att_name, att in desc['atts'].items():
                if not isinstance(att, dict) or 'value' not in att:
                    raise ValueError(  # pragma: no cover
                        "Unexpected value {}.".format(att))
                options[att_name] = att['value']

        self.options = options
        self.dtype = dtype
        self._init(variables)

        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            InvalidArgument as OrtInvalidArgument)
        self.OrtInvalidArgument = OrtInvalidArgument

    def _name_mapping(self, inputs):
        """
        Gives a distinct name to every input of the node.

        The same result may appear several times among the node inputs,
        duplicated names are renamed ``<name>_<i>`` with the first
        integer *i* not already used.

        :param inputs: list of input names
        :return: tuple ``(mapping, new_inputs)``, *mapping* maps every
            new name to the original one, *new_inputs* contains the
            distinct names in the same order as *inputs*
        """
        mapping = {}
        new_inputs = []
        for name in inputs:
            if name in mapping:
                i = 0
                new_name = "{}_{}".format(name, i)
                while new_name in mapping:
                    i += 1  # pragma: no cover
                    new_name = "{}_{}".format(name, i)  # pragma: no cover
                mapping[new_name] = name
                new_inputs.append(new_name)
            else:
                new_inputs.append(name)
                mapping[name] = name
        return mapping, new_inputs

    def _guess_proto_type(self, dtype):
        """
        Returns the ONNX element type matching a numpy dtype,
        delegates to *guess_proto_dtype*.
        """
        return guess_proto_dtype(dtype)

    @staticmethod
    def _check_null_dimension(onnx_model, inputs=None, cause=None):
        """
        Raises an exception if the ONNX graph built for the node
        contains a dimension equal to 0 (``dim_value: 0``).

        :param onnx_model: ONNX model built for the node
        :param inputs: if not None, appended to the error message
        :param cause: if not None, the raised exception is chained to it
        :raises RuntimeError: if a null dimension is found
        """
        if "dim_value: 0" not in str(onnx_model):
            return
        if inputs is None:
            msg = "Probable issue as one dimension is null.\n--\n{}".format(
                onnx_model)
        else:
            msg = "Probable issue as one dimension is null.\n--\n{}\n---\n{}".format(
                onnx_model, inputs)
        if cause is None:
            raise RuntimeError(msg)  # pragma: no cover
        raise RuntimeError(msg) from cause  # pragma: no cover

    def _init(self, variables=None):
        """
        Initializes the node: builds a single-node ONNX graph
        and loads it into an :epkg:`onnxruntime` session.

        :param variables: registered variables created by previous operators
        :raises RuntimeError: if the graph cannot be built or loaded

        The current implementation for operator *Scan*
        only works for matrices.
        """
        # Looks for the algebra class Onnx<op_type>: custom nodes first,
        # then mlprodict, then skl2onnx custom operators, then skl2onnx.
        custom_nodes = self.options.get('nodes', None)
        if (custom_nodes is not None and
                self.onnx_node.op_type in custom_nodes):
            self.alg_class = custom_nodes[self.onnx_node.op_type]
        else:
            try:
                import mlprodict.onnx_conv.onnx_ops as alg0
                self.alg_class = getattr(alg0, 'Onnx' + self.onnx_node.op_type)
            except AttributeError:
                import skl2onnx.algebra.custom_ops as alg2  # delayed
                try:
                    self.alg_class = getattr(
                        alg2, 'Onnx' + self.onnx_node.op_type)
                except AttributeError:
                    import skl2onnx.algebra.onnx_ops as alg  # delayed
                    self.alg_class = getattr(
                        alg, 'Onnx' + self.onnx_node.op_type)

        inputs = list(self.onnx_node.input)
        self.mapping, self.inputs = self._name_mapping(inputs)
        self.outputs = list(self.onnx_node.output)

        options = self.options.copy()
        options.pop('nodes', None)
        target_opset = options.pop('target_opset', None)
        domain = options.pop('domain', None)
        disable_optimisation = options.pop('disable_optimisation', False)
        # None means "not specified" and triggers the default session
        # settings below (a default of False made that test always fail).
        session_options = options.pop('session_options', None)
        ir_version = options.pop('ir_version', None)

        # For the main domain (''), target_opset should be >= 9.
        # We assume it was the case when the graph was created,
        # no check is done here (target_opset may also be None).

        if self.onnx_node.op_type == 'ZipMap':
            from skl2onnx.common.data_types import (  # delayed
                DictionaryType, FloatTensorType, Int64TensorType, StringTensorType)
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            name = (self.outputs[0] if len(self.outputs) == 1
                    else self.inst_.expected_outputs[0][0])
            otype = (Int64TensorType if 'classlabels_int64s' in options
                     else StringTensorType)
            outvar = [(name, DictionaryType(otype([1]), FloatTensorType([1])))]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outvar)
            forced = True
        elif self.onnx_node.op_type == 'ConstantOfShape':
            # numpy arrays given as attributes must be converted into
            # ONNX tensors.
            for k in options:
                v = options[k]
                if isinstance(v, numpy.ndarray):
                    options[k] = make_tensor(
                        k, self._guess_proto_type(v.dtype),
                        v.shape, v.tolist())

            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            try:
                self.onnx_ = self.inst_.to_onnx(inputs, target_opset=target_opset,
                                                domain=domain)
                self._check_null_dimension(self.onnx_)
            except AttributeError as e:  # pragma: no cover
                # older version of skl2onnx
                self.onnx_ = self.inst_.to_onnx(inputs)
                self._check_null_dimension(self.onnx_, cause=e)
            forced = False
        elif self.onnx_node.op_type == 'Scan':
            self.inst_ = self.alg_class(
                *self.inputs, output_names=self.outputs,
                op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype)
            # Only matrices are supported: every input and output
            # becomes a 2D tensor of unknown dimensions.
            inputs = [(name, cl.__class__([None, None]))
                      for (name, cl) in inputs]
            outputs = [(name, cl.__class__([None, None]))
                       for (name, cl) in outputs]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                            target_opset=target_opset,
                                            domain=domain)
            self._check_null_dimension(self.onnx_)
            forced = True
        else:
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, domain=domain,
                                        **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype,
                schema=self.alg_class.expected_inputs)

            try:
                self.onnx_ = self.inst_.to_onnx(
                    inputs, target_opset=target_opset, domain=domain)
                # A RuntimeError raised here is caught below
                # and triggers a second attempt.
                self._check_null_dimension(self.onnx_, inputs=inputs)
                forced = False
            except (RuntimeError, ValueError, InferenceError) as eo:
                # Let's try again by forcing output types.
                forced = True
                outputs = get_defined_outputs(
                    self.outputs, self.onnx_node, inputs, variables,
                    dtype=self.dtype, schema=self.alg_class.expected_outputs,
                    schema_inputs=self.alg_class.expected_inputs)
                try:
                    self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                                    target_opset=target_opset,
                                                    domain=domain)
                except NotImplementedError as e:  # pragma: no cover
                    raise NotImplementedError(
                        "Unable to instantiate node {} inputs={} "
                        "self.inputs={} outputs={} variables={} "
                        "dtype={} e={} eo={}".format(
                            self.alg_class, inputs, self.inputs,
                            outputs, variables, self.dtype, e, eo)) from e
                # Chained to the first failure, the only exception
                # still bound at this point.
                self._check_null_dimension(self.onnx_, cause=eo)

        if len(self.onnx_.graph.output) != len(self.outputs):  # pragma: no cover
            # Something is wrong, falls back to default plan.
            forced = True
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype, schema=self.alg_class.expected_outputs)
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                            target_opset=target_opset,
                                            domain=domain)
            self._check_null_dimension(self.onnx_)
        else:
            lo = list(self.onnx_.graph.output)
            outputs = proto2vars(lo)

        from onnxruntime import (  # pylint: disable=E0611
            SessionOptions, RunOptions, GraphOptimizationLevel)
        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            Fail as OrtFail, InvalidGraph as OrtInvalidGraph,
            NotImplemented as OrtNotImplemented)

        sess_options = session_options or SessionOptions()
        self.run_options = RunOptions()

        if session_options is None:
            try:
                sess_options.session_log_severity_level = 3
            except AttributeError:  # pragma: no cover
                # onnxruntime not recent enough.
                pass
            try:
                self.run_options.run_log_severity_level = 3
            except AttributeError:  # pragma: no cover
                # onnxruntime not recent enough.
                pass
            if disable_optimisation:
                sess_options.graph_optimization_level = (  # pragma: no cover
                    GraphOptimizationLevel.ORT_DISABLE_ALL)
        elif disable_optimisation:
            raise RuntimeError(  # pragma: no cover
                "session_options and disable_optimisation cannot be defined "
                "at the same time.")

        if ir_version is not None:
            self.onnx_.ir_version = ir_version
        try:
            self.sess_ = InferenceSession(
                self.onnx_.SerializeToString(), sess_options=sess_options,
                runtime=self.runtime)
        except (RuntimeError, OrtNotImplemented, OrtInvalidGraph, OrtFail) as e:
            raise RuntimeError(
                "Unable to load node '{}' (output type was {}) inputs={} "
                "self.inputs={} self.onnx_node.input={} "
                "variables={} mapping={} "
                "expected_inputs={}\n{}".format(
                    self.onnx_node.op_type,
                    "guessed" if forced else "inferred",
                    inputs, self.inputs, self.onnx_node.input,
                    variables, self.mapping,
                    self.alg_class.expected_inputs,
                    self.onnx_)) from e
        self.typed_outputs_ = outputs

    def run(self, *args, **kwargs):
        """
        Computes the outputs of the node with :epkg:`onnxruntime`.

        :param args: input values, in the same order as the node inputs
        :param kwargs: unused
        :return: tuple of outputs
        :raises RuntimeError: if the session fails, the message
            describes the received and expected inputs
        """
        inputs = dict(zip(self.inputs, args))

        try:
            res = self.sess_.run(None, inputs, self.run_options)
        except (RuntimeError, self.OrtInvalidArgument) as e:  # pragma: no cover
            dtypes = {k: v.dtype for k, v in inputs.items()}
            shapes = {k: v.shape for k, v in inputs.items()}
            exp = [_.name for _ in self.sess_.get_inputs()]
            exp_types = [_.type for _ in self.sess_.get_inputs()]
            raise RuntimeError(
                "Predictions failed. List of inputs: {}, class={}"
                "\ndtypes={}\nshapes={}\nexpected={}\nexpected={}\n"
                "exception={}\n--ONNX--\n{}".format(
                    list(sorted(inputs)), self.alg_class, dtypes,
                    shapes, exp, exp_types, e, self.onnx_)) from e
        return tuple(res)

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return False