Coverage for mlprodict/npy/onnx_numpy_compiler.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""
@file
@brief Implements :epkg:`numpy` functions with onnx and a runtime.
A function written with the numpy-like API is converted into an ONNX
graph which is then executed with :epkg:`onnxruntime` or the python
runtime (@see cl OnnxInference).

.. versionadded:: 0.6
"""
7import inspect
8import logging
9from typing import Any
10import numpy
11from ..onnx_tools.optim._main_onnx_optim import onnx_optimisations
12from .onnx_version import FctVersion
13from .onnx_numpy_annotation import get_args_kwargs
14from .xop_variable import Variable
15from .xop import OnnxOperator, OnnxOperatorTuple
# Logger receiving the debug traces emitted while a function
# is converted into ONNX (see OnnxNumpyCompiler._to_onnx).
logger = logging.getLogger(name='xop')
class OnnxNumpyFunction:
    """
    Class wrapping a function built with
    @see cl OnnxNumpyCompiler.

    :param compiler: the @see cl OnnxNumpyCompiler which created
        this object
    :param rt: runtime executing the ONNX graph
    :param inputs: list of @see cl Variable, the graph inputs
    :param outputs: list of @see cl Variable, the graph outputs
    :param n_optional: number of trailing inputs which may be omitted
        when the function is called
    :param n_variables: when strictly positive, the number of
        positional arguments is not checked
    :raises TypeError: if an input or an output is not a
        @see cl Variable
    :raises RuntimeError: if *n_optional* is negative or not strictly
        lower than the number of inputs

    .. versionadded:: 0.6
    """

    def __init__(self, compiler, rt, inputs, outputs,
                 n_optional, n_variables):
        if any(not isinstance(n, Variable) for n in inputs):
            raise TypeError(  # pragma: no cover
                "All inputs must be of type Variable: %r." % (inputs, ))
        if any(not isinstance(n, Variable) for n in outputs):
            raise TypeError(  # pragma: no cover
                "All outputs must be of type Variable: %r." % (outputs, ))
        self.compiler = compiler
        self.inputs = inputs
        self.outputs = outputs
        self.rt = rt
        self.n_optional = n_optional
        self.n_variables = n_variables
        if n_optional < 0:
            raise RuntimeError(  # pragma: no cover
                "Wrong configuration, n_optional %r must be >= 0."
                "" % n_optional)
        if n_optional >= len(inputs):
            # At least one input must remain mandatory.
            raise RuntimeError(  # pragma: no cover
                "Wrong configuration, n_optional %r must be < %r "
                "the number of inputs." % (n_optional, len(inputs)))

    def _check_(self, *args, **kwargs):
        """
        Checks the number of positional arguments before the graph
        is executed.

        :param args: positional inputs given to the function
        :param kwargs: additional parameters, only used in the
            error message
        :raises RuntimeError: if the number of positional arguments is
            not in `[len(inputs) - n_optional, len(inputs)]`
        """
        if self.n_variables > 0:
            # Variable number of inputs: nothing can be checked.
            return
        n_inputs = len(self.inputs)
        min_inputs = n_inputs - self.n_optional
        if not min_inputs <= len(args) <= n_inputs:
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of inputs %d. It should be in "
                "[%r, %r] len(args)=%d n_optional=%d n_variables=%d"
                "\nargs=%s\nkwargs=%s\ninputs=%s" % (
                    len(args), min_inputs, n_inputs,
                    len(args), self.n_optional, self.n_variables,
                    args, kwargs, self.inputs))
63 len(self.inputs), args, kwargs, self.inputs))
66class OnnxNumpyFunctionOnnxInference(OnnxNumpyFunction):
67 """
68 Overwrites @see cl OnnxNumpyFunction to run an instance of
69 @see cl OnnxInference.
71 .. versionadded:: 0.6
72 """
74 def __call__(self, *args, **kwargs):
75 self._check_(*args, **kwargs)
76 inp = {k.name: a for k, a in zip(self.inputs, args)}
77 out = self.rt.run(inp, **kwargs)
78 if len(out) != len(self.outputs):
79 raise RuntimeError( # pragma: no cover
80 "Unexpected number of outputs %d instead of %d." % (
81 len(out), len(self.outputs)))
82 return tuple([out[o.name] for o in self.outputs])
class OnnxNumpyFunctionInferenceSession(OnnxNumpyFunction):
    """
    Specializes @see cl OnnxNumpyFunction for a runtime which is
    an `InferenceSession` from :epkg:`onnxruntime`.

    .. versionadded:: 0.6
    """

    def __call__(self, *args, **kwargs):
        """
        Runs the graph and returns all its outputs.

        :param args: inputs, matched by position with `self.inputs`
        :param kwargs: must be empty, this runtime takes no option
        :return: tuple of results, in the order the runtime returns them
        :raises RuntimeError: if *kwargs* is not empty or if the runtime
            does not return as many results as there are declared outputs
        """
        self._check_(*args, **kwargs)
        if kwargs:
            raise RuntimeError(  # pragma: no cover
                "kwargs is not used but it is not empty: %r." % kwargs)
        names = [var.name for var in self.inputs]
        results = self.rt.run(None, dict(zip(names, args)))
        expected = len(self.outputs)
        if len(results) != expected:
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of outputs %d instead of %d." % (
                    len(results), expected))
        return tuple(results)
class OnnxNumpyCompiler:
    """
    Implements a class which runs an onnx graph.

    :param fct: a function with annotations which returns an ONNX graph,
        it can also be an ONNX graph (any object with a method
        `SerializeToString`)
    :param op_version: :epkg:`ONNX` opset to use, None
        for the latest one
    :param runtime: runtime to choose to execute the onnx graph,
        `onnxruntime` and `onnxruntime-cuda` use `InferenceSession`,
        any other value (`python`, `onnxruntime1`, ...) is given to
        @see cl OnnxInference
    :param signature: used when the function is not annotated
    :param version: the same function can be instantiated with
        different types, this parameter is None or, if the signature
        allows multiple types, an instance of @see cl FctVersion
    :param fctsig: function used to overwrite the fct signature
        in case this one is using `*args, **kwargs`
    :raises TypeError: if *version* is not a @see cl FctVersion or
        if *fct* is neither an ONNX graph nor a function

    Pickling only keeps the ONNX graph and the metadata, the runtime
    is rebuilt when the object is restored.

    .. versionadded:: 0.6
    """

    def __init__(self, fct, op_version=None, runtime=None, signature=None,
                 version=None, fctsig=None):
        if version is not None and not isinstance(version, FctVersion):
            raise TypeError(  # pragma: no cover
                "version must be of Type 'FctVersion' not %s - %s"
                "." % (type(version), version))
        self.fctsig = fctsig
        if op_version is None:
            from .. import __max_supported_opset__
            op_version = __max_supported_opset__
        if hasattr(fct, 'SerializeToString'):
            # fct is already an ONNX graph, there is nothing to convert.
            self.fct_ = None
            self.onnx_ = fct
        else:
            self.fct_ = fct
            if not inspect.isfunction(fct):
                raise TypeError(  # pragma: no cover
                    "Unexpected type for fct=%r, it must be a "
                    "function." % type(fct))
            # _to_onnx only converts the function when onnx_ is None.
            self.onnx_ = None
            self.onnx_ = self._to_onnx(
                op_version=op_version, signature=signature,
                version=version)
        self.runtime_ = self._build_runtime(
            op_version=op_version, runtime=runtime,
            signature=signature, version=version)
        inputs, outputs, kwargs, n_optional, n_variables = (
            self._parse_annotation(signature=signature, version=version))
        n_opt = 0 if signature is None else signature.n_optional
        args, kwargs2 = get_args_kwargs(self.fctsig or self.fct_, n_opt)
        # Everything needed to rebuild the runtime without fct_,
        # see __getstate__ and __setstate__.
        self.meta_ = dict(op_version=op_version, runtime=runtime,
                          signature=signature, version=version,
                          inputs=inputs, outputs=outputs,
                          kwargs=kwargs, n_optional=n_optional,
                          n_variables=n_variables,
                          args=args, kwargs2=kwargs2,
                          annotations=self._fct_annotations())

    def _fct_annotations(self):
        """
        Returns the annotations of function `fct_` or an empty
        dictionary when there is no function (the compiler was
        created from an ONNX graph).
        """
        if self.fct_ is None:
            return {}
        return self.fct_.__annotations__

    def __getstate__(self):
        """
        Serializes everything but function `fct_`.
        Function `fct_` is used to build the onnx graph
        and is not needed anymore.
        """
        return dict(onnx_=self.onnx_, meta_=self.meta_)

    def __setstate__(self, state):
        """
        Restores serialized data and rebuilds the runtime.
        Attributes which are not serialized (`fct_`, `fctsig`)
        are set to None.
        """
        self.fct_ = None
        self.fctsig = None
        for k, v in state.items():
            setattr(self, k, v)
        self.runtime_ = self._build_runtime(
            op_version=self.meta_['op_version'],
            runtime=self.meta_['runtime'],
            signature=self.meta_['signature'],
            version=self.meta_['version'])

    def __repr__(self):
        "usual"
        if self.fct_ is not None:
            return "%s(%s)" % (self.__class__.__name__, repr(self.fct_))
        if self.onnx_ is not None:
            return "%s(%s)" % (self.__class__.__name__, "... ONNX ... ")
        raise NotImplementedError(  # pragma: no cover
            "fct_ and onnx_ are empty.")

    def _to_onnx_shape(self, shape):
        """
        Converts an annotated shape into an ONNX shape.

        :param shape: `Any` or `Ellipsis` (unknown shape) or a tuple
            whose unknown dimensions are `Any` or `Ellipsis`
        :return: None for an unknown shape, a list otherwise where
            unknown dimensions are replaced by None
        :raises RuntimeError: for any other kind of shape
        """
        if shape is Any or shape is Ellipsis:
            return None
        if isinstance(shape, tuple):
            return [None if s is Any or s is Ellipsis else s
                    for s in shape]
        raise RuntimeError(  # pragma: no cover
            "Unexpected annotated shape %r." % shape)

    @staticmethod
    def _possible_names():
        "Yields the candidate names for the outputs."
        yield 'y'
        yield 'z'  # pragma: no cover
        yield 'o'  # pragma: no cover
        for i in range(0, 10000):  # pragma: no cover
            yield 'o%d' % i

    def _parse_annotation(self, signature, version):
        """
        Returns the annotations for function `fct_`.

        :param signature: needed if the annotation is missing,
            then version might be needed to specify which type
            to use if the signature allows many
        :param version: version inside the many signatures possible
        :return: *tuple(inputs, outputs, kwargs, n_optional, n_variables)*,
            *inputs* and *outputs* are lists of @see cl Variable,
            *kwargs* holds the additional parameters
        :raises RuntimeError: if *version* does not match the signature,
            if an additional parameter is a type or if an argument
            is not annotated
        """
        n_opt = 0 if signature is None else signature.n_optional
        if hasattr(self, 'meta_'):
            # meta_ exists once __init__ is done or after unpickling.
            args, kwargs = self.meta_['args'], self.meta_['kwargs2']
        else:
            args, kwargs = get_args_kwargs(self.fctsig or self.fct_, n_opt)
        if version is not None:
            nv = len(version) - len(args) - n_opt
            if (signature is not None and not
                    signature.n_variables and nv > len(kwargs)):
                raise RuntimeError(  # pragma: no cover
                    "Mismatch (%d - %d - %d ? %d) between version=%r and "
                    "kwargs=%r for function %r, n_variables=%r, "
                    "signature=%r." % (
                        len(version), len(args), n_opt, len(kwargs),
                        version, kwargs, self.fct_,
                        signature.n_variables, signature))
            vvers = {} if version.kwargs is None else version.kwargs
            # Values stored in the version replace the default values
            # of the additional parameters, matched by position.
            up = dict(zip(kwargs, vvers))
            kwargs = kwargs.copy()
            kwargs.update(up)

        for k, v in kwargs.items():
            if isinstance(v, (type, numpy.dtype)):
                raise RuntimeError(  # pragma: no cover
                    "Unexpected value for argument %r: %r from %r." % (
                        k, v, kwargs))

        if signature is not None:
            inputs, kwargs, outputs, n_optional, n_variables = (
                signature.get_inputs_outputs(args, kwargs, version))
            inputs = [Variable(i[0], i[1]) for i in inputs]
            outputs = [Variable(i[0], i[1]) for i in outputs]
            return inputs, outputs, kwargs, n_optional, n_variables

        # No signature: shapes and types come from the annotations.
        if hasattr(self, 'meta_'):
            annotations = self.meta_['annotations']
        else:
            annotations = self._fct_annotations()
        inputs = []
        for a in args:
            if a == "op_version":
                continue
            if a not in annotations:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find annotation for argument %r. "
                    "You should annotate the arguments and the results "
                    "or specify a signature." % a)
            shape, dtype = annotations[a].__args__
            inputs.append(
                Variable(a, dtype, shape=self._to_onnx_shape(shape)))

        ret = annotations['return']
        # A tuple annotation means several outputs.
        returns = ret if isinstance(ret, tuple) else (ret, )
        taken = {inp.name for inp in inputs}
        outputs = []
        for shape_dtype in returns:
            shape, dtype = shape_dtype.__args__
            shape = self._to_onnx_shape(shape)
            # First candidate name not used by an input or an output.
            name_out = next(
                (name for name in self._possible_names()
                 if name not in taken), None)
            outputs.append(Variable(name_out, dtype, shape=shape))
            taken.add(name_out)
        # signature is None here, there is no variable number of inputs.
        return inputs, outputs, kwargs, 0, False

    def _find_hidden_algebras(self, onx_var, onx_algebra):
        """
        Subgraph are using inputs not linked to the others nodes.
        This function retrieves them as they are stored in
        attributes `alg_hidden_var_`. The function looks into every
        node linked to the inputs and their predecessors.
        Every node is visited once even if it is shared.

        :param onx_var: @see cl OnnxVar
        :param onx_algebra: OnnxOperator (unused)
        :return: tuple(dictionary `{id(obj): (var, obj)}`,
            all instance of @see cl OnnxVarGraph)
        """
        keep_hidden = {}
        var_graphs = []
        stack = [onx_var]
        seen = set()
        while stack:
            var = stack.pop()
            # A node used several times is explored only once,
            # otherwise shared sub-expressions are revisited.
            if id(var) in seen:
                continue
            seen.add(id(var))
            hidden = getattr(var, 'alg_hidden_var_', None)
            if hidden is not None:
                if any(len(x) > 0
                       for x in var.alg_hidden_var_inputs.values()):
                    keep_hidden.update(hidden)
                    var_graphs.append(var)
            if hasattr(var, 'inputs'):
                stack.extend(var.inputs)
        return keep_hidden, var_graphs

    def _to_onnx(self, op_version=None, signature=None, version=None):
        """
        Returns the onnx graph produced by function `fct_`.
        The function is converted and the graph optimised only once,
        the result is cached in attribute `onnx_`.

        :param op_version: :epkg:`ONNX` opset to use
        :param signature: used when the function is not annotated
        :param version: types to use if the signature allows many
        :return: ONNX graph
        :raises RuntimeError: if no graph can be produced
        """
        if self.onnx_ is None and self.fct_ is not None:
            from .onnx_variable import OnnxVar
            logger.debug('OnnxNumpyCompiler._to_onnx(op_version=%r, '
                         'signature=%r, version=%r)',
                         op_version, signature, version)
            inputs, outputs, kwargs, n_optional, n_variables = (  # pylint: disable=W0612
                self._parse_annotation(
                    signature=signature, version=version))
            if ((signature is None or not signature.n_variables) and
                    isinstance(version, tuple) and
                    len(inputs) > len(version)):
                raise NotImplementedError(  # pragma: no cover
                    "Mismatch between additional parameters %r "
                    "(n_optional=%r) and version %r for function %r from %r."
                    "" % (kwargs, n_optional, version, self.fct_,
                          getattr(self.fct_, '__module__', None)))
            names_in = [oi.name for oi in inputs]
            names_out = [oi.name for oi in outputs]
            names_var = [OnnxVar(n, dtype=dt.dtype)
                         for n, dt in zip(names_in, inputs)]

            logger.debug('OnnxNumpyCompiler._to_onnx:names_in=%r', names_in)
            logger.debug('OnnxNumpyCompiler._to_onnx:names_out=%r', names_out)

            # NOTE: co_varnames lists the parameters and also the local
            # variables, a local variable named op_version selects
            # this branch as well.
            if 'op_version' in self.fct_.__code__.co_varnames:
                # The function receives the input names and must
                # return an OnnxOperator.
                onx_var = None
                onx_algebra = self.fct_(
                    *names_in, op_version=op_version, **kwargs)
            else:
                # The function receives OnnxVar and returns an OnnxVar.
                onx_var = self.fct_(*names_var, **kwargs)
                if not hasattr(onx_var, 'to_algebra'):
                    raise TypeError(  # pragma: no cover
                        "The function %r to convert must return an instance of "
                        "OnnxVar but returns type %r." % (self.fct_, type(onx_var)))
                onx_algebra = onx_var.to_algebra(op_version=op_version)

            logger.debug('OnnxNumpyCompiler._to_onnx:onx_var=%r',
                         type(onx_var))
            logger.debug('OnnxNumpyCompiler._to_onnx:onx_algebra=%r',
                         type(onx_algebra))

            if not isinstance(onx_algebra, (OnnxOperator, OnnxOperatorTuple)):
                raise TypeError(  # pragma: no cover
                    "Unexpected type for onx_algebra %r "
                    "(It should be OnnxOperator or OnnxOperatorTuple), "
                    "function is %r." % (type(onx_algebra), self.fct_))
            hidden_algebras, var_graphs = self._find_hidden_algebras(
                onx_var, onx_algebra)
            if len(hidden_algebras) > 0:
                logger.debug(  # pragma: no cover
                    'OnnxNumpyCompiler._to_onnx:len(hidden_algebras)=%r',
                    len(hidden_algebras))
                raise NotImplementedError(  # pragma: no cover
                    "Subgraphs only support constants (operator If, Loop, "
                    "Scan). hidden_algebras=%r var_graphs=%r" % (
                        hidden_algebras, var_graphs))

            if isinstance(onx_algebra, str):
                raise RuntimeError(  # pragma: no cover
                    "Unexpected str type %r." % onx_algebra)
            if isinstance(onx_algebra, tuple):
                raise NotImplementedError(  # pragma: no cover
                    "Not implemented when the function returns multiple results.")
            if hasattr(onx_algebra, 'to_onnx'):
                onx_algebra.output_names = [Variable(n) for n in names_out]
                onx = onx_algebra.to_onnx(
                    inputs=inputs, target_opset=op_version, outputs=outputs)
                # optimisation
                self.onnx_ = onnx_optimisations(onx)

        if self.onnx_ is None:
            raise RuntimeError(  # pragma: no cover
                "Unable to get the ONNX graph (class %r, fct_=%r)" % (
                    type(self), self.fct_))
        return self.onnx_

    def to_onnx(self, **kwargs):
        """
        Returns the ONNX graph for the wrapped function.
        It takes additional arguments to distinguish between multiple graphs.
        This happens when a function needs to support multiple type.

        :return: ONNX graph
        :raises NotImplementedError: if *kwargs* is not empty
        """
        if len(kwargs) > 0:
            raise NotImplementedError(  # pragma: no cover
                "kwargs is not empty, this case is not implemented. "
                "kwargs=%r." % kwargs)
        if hasattr(self, 'onnx_'):
            return self.onnx_
        raise NotImplementedError(  # pragma: no cover
            "Attribute 'onnx_' is missing.")

    def _build_runtime(self, op_version=None, runtime=None,
                       signature=None, version=None):
        """
        Creates the runtime for the :epkg:`ONNX` graph.

        :param op_version: :epkg:`ONNX` opset to use
        :param runtime: runtime to choose to execute the onnx graph,
            `onnxruntime` and `onnxruntime-cuda` use `InferenceSession`,
            any other value is given to @see cl OnnxInference
        :param signature: used when the function is not annotated
        :param version: types to use if the signature allows many
        :return: the callable also stored in attribute `rt_fct_`
        """
        onx = self._to_onnx(op_version=op_version, signature=signature,
                            version=version)
        inputs, outputs, _, n_optional, n_variables = self._parse_annotation(
            signature=signature, version=version)
        if runtime in ('onnxruntime', 'onnxruntime-cuda'):
            from ..tools.ort_wrapper import InferenceSession
            rt = InferenceSession(onx.SerializeToString(), runtime=runtime)
            wrapper_cls = OnnxNumpyFunctionInferenceSession
        else:
            from ..onnxrt import OnnxInference
            rt = OnnxInference(onx, runtime=runtime)
            wrapper_cls = OnnxNumpyFunctionOnnxInference
        self.rt_fct_ = wrapper_cls(
            self, rt, inputs=inputs, outputs=outputs,
            n_optional=n_optional, n_variables=n_variables)
        return self.rt_fct_

    def __call__(self, *args, **kwargs):
        """
        Executes the function and returns the results.

        :param args: arguments
        :param kwargs: additional parameters given to the runtime
        :return: the only result if the graph has one output,
            all of them otherwise
        """
        res = self.rt_fct_(*args, **kwargs)
        if len(res) == 1:
            return res[0]
        return res