Coverage for mlprodict/npy/xop.py: 90%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# pylint: disable=E1101,C0302
2"""
3@file
4@brief Xop API to build onnx graphs. Inspired from :epkg:`sklearn-onnx`.
6.. versionadded:: 0.9
7"""
8import os
9import pprint
10import logging
11import hashlib
12from collections import OrderedDict
13import numpy
14from scipy.sparse.coo import coo_matrix
15import onnx
16from onnx import GraphProto, TensorProto, ValueInfoProto
17from onnx.helper import (
18 make_node, make_graph, make_model, make_value_info,
19 make_tensor_value_info, make_function, make_opsetid,
20 make_tensor_type_proto, make_operatorsetid)
21from onnx.numpy_helper import from_array, to_array
22from onnx.shape_inference import infer_shapes
23from ._cache import cache_folder
24from .xop_variable import (
25 Variable, is_numpy_dtype, numpy_type_prototype, max_supported_opset,
26 DetectedVariable, InputDetectedVariable, OutputDetectedVariable,
27 NodeResultName, guess_numpy_type)
28from .xop_auto import get_rst_doc
# Module-level logger; enable the 'xop' logger to trace graph construction.
logger = logging.getLogger('xop')
34def _default_OPSET_TO_IR_VERSION():
35 """
36 Returns the default mapping between opset and ir_version.
38 .. runpython::
39 :showcode:
41 import pprint
42 from mlprodict.npy.xop import _default_OPSET_TO_IR_VERSION
43 pprint.pprint(_default_OPSET_TO_IR_VERSION())
44 """
45 return {
46 1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 3,
47 7: 3, 8: 4, 9: 4, 10: 5, 11: 6, 12: 7,
48 13: 7, 14: 7, 15: 8, 16: 8}
51def _domain_to_class_name(domain):
52 """
53 Converts domain into a name.
55 :param domain: domain name such as `ai.onnx.ml`
56 :return: string
58 .. runpython::
59 :showcode:
61 from mlprodict.npy.xop import _domain_to_class_name
62 print(_domain_to_class_name('ai.onnx.ml'))
63 """
64 if domain == 'ai.onnx':
65 return ''
66 dom = domain.split('.')
67 res = []
68 for d in dom:
69 if len(d) == 0:
70 res.append(d)
71 elif len(d) == 1:
72 res.append(d.upper())
73 else:
74 res.append(d[0].upper() + d[1:])
75 return "".join(res)
def _populate_schemas():
    """
    Populates all schemas.

    Collects every non-experimental operator schema exposed by *onnx*
    and returns a tuple ``(res, versions, domains)``:

    * *res*: maps ``(domain, name)`` and ``(domain, name_version)``
      to a schema, keeping the highest *since_version* for the
      unversioned key,
    * *versions*: maps ``(domain, name)`` to the set of versioned names,
    * *domains*: maps an operator name to the set of domains defining it.
    """
    res = {}
    versions = {}
    domains = {}
    for schema in onnx.defs.get_all_schemas_with_history():
        if schema.support_level == schema.SupportType.EXPERIMENTAL:
            # Skips experimental operators.
            continue
        # Multiple versions can coexist. The most recent one is kept.
        key = schema.domain, schema.name
        # BUG FIX: the previous code tested `schema.name in res`, but
        # *res* is keyed by (domain, name) tuples, so the membership
        # test was always False and the last iterated schema silently
        # overwrote any earlier one regardless of its version.
        if key in res:
            if schema.since_version > res[key].since_version:
                # We keep the most recent one.
                res[key] = schema
        else:
            res[key] = schema
        full_name = schema.name + '_' + str(schema.since_version)
        res[schema.domain, full_name] = schema
        if key not in versions:
            versions[key] = set()
        if schema.name not in domains:
            domains[schema.name] = set()
        domains[schema.name].add(schema.domain)
        versions[key].add(full_name)
    return res, versions, domains
def _find_operator_domain(name):
    """
    Determines the domain of an operator.
    Raises an exception if not found or if there is an ambiguity.

    :param name: operator name
    :return: domain
    :raises ValueError: if the operator is unknown or belongs to
        more than one domain
    """
    if name not in _all_domains:
        raise ValueError(
            "Unable to guess domain for operator %r. "
            "Not found in %r." % (name, list(_all_domains)))
    candidates = _all_domains[name]
    if len(candidates) != 1:
        raise ValueError(  # pragma: no cover
            "Unable to guess domain of operator %r, found domains %r." % (
                name, candidates))
    return list(candidates)[0]
def ClassFactory(class_name, op_name, inputs, outputs,
                 input_range, output_range,
                 domain, attr_names, doc,
                 deprecated, since_version,
                 past_version):
    """
    Dynamically creates a class for a specific operator.

    :param class_name: class name
    :param op_name: operator type
    :param inputs: expected inputs
    :param outputs: expected outputs
    :param input_range: input range
    :param output_range: output_range
    :param domain: domain
    :param attr_names: attributes names
    :param doc: docstring
    :param deprecated: is the operator deprecated
    :param since_version: available since version
    :param past_version: list of versions
    :return: a new class deriving from @see cl OnnxOperator
    """

    def __init__(self, *args, **kwargs):
        # Constructor of the generated class: resolves the opset
        # version, validates keyword arguments against the operator
        # attributes, checks input types, then delegates to
        # OnnxOperator.__init__.

        # 'op_version' may be a dictionary mapping domain -> opset.
        op_version = kwargs.pop('op_version', None)
        if isinstance(op_version, dict):
            op_version = op_version.get(domain, None)

        if op_version is None:
            if len(args) == 0 and input_range[0] == input_range[1]:
                # No input given and the operator takes a fixed number
                # of inputs: fall back on the default input names.
                args = [_[0] for _ in self.__class__.expected_inputs]
            if not (input_range[0] <= len(args) <= input_range[1]):
                raise RuntimeError(  # pragma: no cover
                    "Unexpected number of inputs, "
                    "got {}, expecting {} for operator "
                    "'{}'.".format(
                        len(args), len(inputs), op_name))

        attr_names = self.attr_names
        if '_' in self.__class__.__name__:
            # Versioned class such as 'OnnxAdd_7': the numeric suffix
            # pins the opset the class represents.
            op_version_class = int(self.__class__.__name__.split('_')[-1])
            if op_version is None:
                op_version = op_version_class
            try:
                op_version = min(op_version, op_version_class)
            except TypeError:  # pragma: no cover
                raise TypeError(  # pylint: disable=W0707
                    "Could not compare versions {} ? {} for "
                    "class '{}' since_version {}. Parameter 'op_version' "
                    "is probably missing when the class "
                    "is instantiated.".format(
                        op_version, op_version_class, class_name,
                        since_version))
        else:
            op_version_class = None

        # By default, the op_version is None.
        # None means the latest available.
        if op_version is None:
            op_version = since_version

        found = None
        if op_version is not None:
            # attr_names refers to the most recent version of
            # this operator. We may need an older one.
            for op in range(op_version, 0, -1):
                name = '{}_{}'.format(self.__class__.__name__, op)
                if name in self.past_version:
                    found = (name, op)
                    attr_names = self.past_version[name].attr_names
                    break
        if (op_version_class is not None and found is not None and
                found[-1] != op_version_class):
            raise RuntimeError(  # pragma: no cover
                "op_version={} does not refer to the same opset as the class "
                "name ('{}').".format(op_version, self.__class__.__name__))
        # Any keyword which is not a reserved parameter must be a known
        # attribute of the (possibly older) operator schema.
        for key in kwargs:
            if key in {'output_names', 'op_version', 'domain', 'ir_version',
                       'global_context', 'clear_subgraph_inputs'}:
                continue
            if key not in attr_names:
                raise TypeError(  # pragma: no cover
                    "Argument '%s' not valid for '%s' opset=%s."
                    % (key, op_name, op_version))

        if op_version is not None:
            kwargs['op_version'] = op_version
        # This class can only be created by a user. Let's check
        # types are either a variable, an operator or an array.
        for i, a in enumerate(args):
            if isinstance(a, tuple):
                # A tuple input must be a pair (name, type).
                if len(a) != 2:
                    raise TypeError(  # pragma: no cover
                        "Input %r is a tuple or class %r, it must have two "
                        "elements (name, type) not %r." % (i, class_name, a))
                if not isinstance(a[0], str):
                    raise TypeError(  # pragma: no cover
                        "Input %r is a tuple or class %r, it must be a tuple "
                        "(name, type) not %r." % (i, class_name, a))
                continue
            if not isinstance(a, (
                    Variable, OnnxOperator, numpy.ndarray, str,
                    OnnxOperatorItem, coo_matrix)):
                raise TypeError(  # pragma: no cover
                    "Unexpected type %r for input %r of operator %r. "
                    "It must be an instance of Variable (or a string), "
                    "OnnxOperator, OnnxOperatorItem, numpy.ndarray, "
                    "coo_matrix)." % (
                        type(a), i, class_name))
        OnnxOperator.__init__(self, *args, **kwargs)

    # Builds the class with the metadata the rest of the module relies
    # on (expected inputs/outputs, ranges, versions, attribute names).
    newclass = type(class_name, (OnnxOperator,),
                    {"__init__": __init__, '__doc__': doc,
                     'expected_inputs': inputs,
                     'expected_outputs': outputs,
                     'operator_name': op_name,
                     'input_range': input_range,
                     'output_range': output_range,
                     'domain': domain,
                     'is_deprecated': deprecated,
                     'since_version': since_version,
                     'past_version': past_version,
                     'attr_names': attr_names,
                     'op_type': op_name,
                     '__module__': __name__})
    return newclass
def _dynamic_class_creation(operator_names=None, cache=False, include_past=False,
                            verbose=0, fLOG=print):
    """
    Automatically generates classes for each of the operators
    module *onnx* defines and described at
    `Operators
    <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
    and `Operators
    <https://github.com/onnx/onnx/blob/master/docs/
    Operators-ml.md>`_.

    :param operator_names: list of operators to request or None for all
    :param cache: extract the documentation from onnx package and
        saves it on disk it True
    :param include_past: includes past versions if operator_names is None
    :param verbose: display some progress
    :param fLOG: logging function
    :return: list of requested operators as a tuple
    """
    def _c(obj, label, i):
        # Builds a (name, type-string) pair for a formal input/output;
        # falls back on label + position when the schema has no name.
        name = '%s%d' % (obj.name or label, i)
        tys = obj.typeStr or ''
        return (name, tys)

    cache_dir = cache_folder()
    if operator_names is None:
        # No explicit request: take every known (domain, name) pair,
        # optionally adding every versioned name ('Add_7', ...).
        operator_names = list(_all_schemas_versions)
        if include_past:
            add = []
            for domain, op in operator_names:
                add.extend(
                    [(domain, k)
                     for k in _all_schemas_versions[domain, op]])
            operator_names.extend(add)
        operator_names.sort()

    # type verification
    # Normalizes every requested name into a (domain, name) pair.
    ops = []
    for name in operator_names:
        if isinstance(name, str):
            if name.startswith('Onnx'):
                raise ValueError(
                    "Operator name cannot start with Onnx: %r." % name)
            # The domain is guessed from the undecorated operator name.
            domain = _find_operator_domain(name.split('_', maxsplit=1)[0])
            ops.append((domain, name))
        elif isinstance(name, tuple) and len(name) == 2:
            if name[1].startswith('Onnx'):
                raise ValueError(  # pragma: no cover
                    "Operator name cannot starts with Onnx: %r." % name)
            ops.append(name)
        else:
            raise ValueError(  # pragma: no cover
                "Operator to fetch must be a string or a "
                "`tuple(domain, name)` not %r." % (name))
    operator_names = ops

    # versions
    res = _all_schemas
    cls = {}
    # set_names maps (domain, name) to the position in the request;
    # -1 marks names added implicitly (not to be returned).
    set_names = dict()
    # set_skip collects base names of versioned requests so the
    # unversioned class is built but not returned.
    set_skip = set()
    for pos, (op_domain, op_name) in enumerate(operator_names):
        if op_domain == 'ai.onnx':
            # Default domain is stored as the empty string.
            op_domain = ''
        set_names[op_domain, op_name] = pos
        if '_' in op_name and not include_past:
            n = op_name.split('_')[0]
            set_skip.add((op_domain, n))
            if n not in set_names:
                set_names[op_domain, n] = -1

    if verbose > 1 and fLOG is not None:
        fLOG(  # pragma: no cover
            "[_dynamic_class_creation] set_names=%r" % set_names)
        fLOG(  # pragma: no cover
            "[_dynamic_class_creation] set_skip=%r" % set_skip)

    returned_classes = []
    positions = {}

    for (op_domain, op_name), position in set_names.items():
        cl_name = 'Onnx' + _domain_to_class_name(op_domain) + op_name
        if verbose > 3 and fLOG is not None:
            fLOG(  # pragma: no cover
                '[_dynamic_class_creation] cl_name=%r op_domain=%r op_name=%r (in=%d)' % (
                    cl_name, op_domain, op_name, 1 if cl_name in _all_classes else 0))
        if cl_name in _all_classes:
            # Already generated in a previous call: reuse the cache.
            if cl_name not in set_skip:
                if position >= 0:
                    returned_classes.append((position, _all_classes[cl_name]))
            continue

        # operator name without domain
        if '_' in op_name:
            names = [op_name]
        else:
            try:
                names = _all_schemas_versions[op_domain, op_name].copy()
            except KeyError as e:  # pragma: no cover
                raise ValueError(
                    "Operator %r (domain=%r) does not exists." % (
                        op_name, op_domain)) from e
            names.add(op_name)

        if verbose > 0 and fLOG is not None:
            fLOG(  # pragma: no cover
                "[_dynamic_class_creation] op_domain=%r op_name=%r, cl_name=%r names=%r"
                "" % (op_domain, op_name, cl_name, names))

        for name in names:
            try:
                schema = res[op_domain, name]
            except KeyError as e:
                raise ValueError(
                    "Operator (%r, %r) does not exists (available=%r)" % (
                        op_domain, name, pprint.pformat(list(res)))) from e
            inputs = [_c(o, 'I', i) for i, o in enumerate(schema.inputs)]
            outputs = [_c(o, 'O', i) for i, o in enumerate(schema.outputs)]
            args = [p for p in schema.attributes]

            if '_' in name:
                class_name = "Onnx" + _domain_to_class_name(op_domain) + name
            else:
                class_name = (
                    "Onnx" + _domain_to_class_name(op_domain) + schema.name)

            if verbose > 0 and fLOG is not None:
                fLOG(  # pragma: no cover
                    "[_dynamic_class_creation] op_name=%r, cl_name=%r cache=%r"
                    "" % (op_name, class_name, cache))

            # Documentation is either read from the on-disk cache or
            # generated from the schema (and written back when *cache*).
            filename = os.path.join(
                cache_dir,
                schema.name + '_' + str(schema.since_version) + ".rst")
            if not cache and os.path.exists(filename):
                with open(filename, "r", encoding="utf-8") as f:  # pragma: no cover
                    doc = f.read()
            else:
                doc = get_rst_doc(schema)
                if cache:  # pragma: no cover
                    with open(filename, 'w', encoding='utf-8') as f:
                        f.write(doc)

            cl = ClassFactory(class_name, schema.name, inputs, outputs,
                              [schema.min_input, schema.max_input],
                              [schema.min_output, schema.max_output],
                              schema.domain, args,
                              "**Version**" + doc.split('**Version**')[-1],
                              getattr(schema, 'deprecated', False),
                              schema.since_version, {})
            cls[class_name] = cl
            if name == op_name:
                positions[class_name] = position

    # Retrieves past classes.
    # Links every versioned class ('OnnxAdd_7') into the 'past_version'
    # dictionary of its unversioned counterpart ('OnnxAdd').
    for name in cls:  # pylint: disable=C0206
        if '_' not in name:
            continue
        main, _ = name.split('_')
        if main in cls:  # pylint: disable=R1715
            last = cls[main]
        else:
            last = _all_classes[main]
        last.past_version[name] = cls[name]

    # final
    _all_classes.update(cls)
    for cl_name, v in cls.items():
        # NOTE(review): `v` is a class object while `set_skip` holds
        # (domain, name) tuples, so `v not in set_skip` looks like it is
        # always true — confirm whether the test should compare a name.
        if v not in set_skip and positions.get(cl_name, -1) >= 0:
            returned_classes.append((positions[cl_name], v))

    returned_classes.sort()
    return tuple(e[1] for e in returned_classes)
def loadop(*names, cache=False, verbose=0, fLOG=print):
    """
    Dynamically creates a class for a every operator type in
    the given list.
    """
    classes = _dynamic_class_creation(
        names, cache=cache, verbose=verbose, fLOG=fLOG)
    # A single requested operator is returned unwrapped.
    return classes[0] if len(classes) == 1 else classes
class OnnxLoadFactory:
    """
    Lazily creates operator classes on attribute access.

    Creating every operator class exposed by the *onnx* package
    upfront takes time. That's why function @see cl loadop only
    creates classes for the requested operators. This class does
    the same the first time an attribute is requested.

    ::

        cl = OnnxLoadOperators()
        x = cl.Add(...)

    It is equivalent to:

    ::

        OnnxAdd = loadop('Add')
        x = OnnxAdd(...)
    """

    def __init__(self):
        # Cache of generated classes, keyed both by the requested
        # attribute name and by the generated class name.
        self._loaded_classes = {}

    def __getattr__(self, name):
        """
        Enables expressions such as:

        ::

            ops = OnnxLoadFactory()
            op = ops.Abs('X')
        """
        if name == '_loaded_classes':
            return self._loaded_classes
        cache = self._loaded_classes
        if name in cache:
            return cache[name]
        cl = loadop(name)
        cache[name] = cl
        cache[cl.__name__] = cl
        return cl
class OnnxOperatorBase:
    """
    Common ancestor to @see cl OnnxOperator, @see cl OnnxOperatorItem,
    @see cl OnnxOperatorTuple.
    """

    def __init__(self):
        pass

    def add_to(self, builder):
        "Must be overridden in subclasses."
        raise NotImplementedError(  # pragma: no cover
            "Not overwritten for class %r." % type(self))

    @property
    def output_names(self):
        "Must be overridden in subclasses."
        raise NotImplementedError(  # pragma: no cover
            "Not overwritten for class %r." % type(self))

    def find_named_inputs(self):
        """
        Returns all inputs to the graph.
        """
        raise NotImplementedError(  # pragma: no cover
            "Method 'find_named_inputs' must be overloaded for type %s."
            "" % type(self))

    def f(self, *args, **kwargs):
        """
        Evaluates this node.
        """
        raise NotImplementedError(  # pragma: no cover
            "Method 'f' must be overloaded for type %s." % type(self))
class OnnxOperatorItem(OnnxOperatorBase):
    """
    Accessor to a single output returned by a @see cl OnnxOperator.

    :param onx_op: @see cl OnnxOperator
    :param index: integer
    :param op_version: defines the opset version
    """

    def __init__(self, onx_op, index, op_version=None):
        OnnxOperatorBase.__init__(self)
        if not isinstance(index, int):
            raise TypeError(  # pragma: no cover
                "index must be an integer not %r." % type(index))
        logger.debug("OnnxOperatorItem(%r, %d, op_version=%r)",
                     onx_op, index, op_version)
        if not isinstance(onx_op, OnnxOperatorBase):
            raise TypeError(  # pragma: no cover
                "onx_op must be an OnnxOperator not %r." % type(onx_op))
        # Parent node and position of the output this item exposes.
        self.onx_op = onx_op
        self.index = index
        self.op_version = op_version

    @property
    def output_names(self):
        "Always returns None."
        return None

    @property
    def inputs(self):
        "Returns the reference to the parent output as a one-element list."
        return [NodeResultName(self.onx_op, self.index)]

    def add_to(self, builder):
        """
        Adds to graph builder. This is a no-op because the
        original node is already added.

        :param builder: instance of @see cl _GraphBuilder,
            it must have a method `add_node`
        """
        pass

    def __str__(self):
        "usual"
        return "%s[%d]" % (str(self.onx_op), self.index)

    def __repr__(self):
        "usual"
        parent_name = self.onx_op.__class__.__name__
        return "%s(%s[%d])" % (
            self.__class__.__name__, parent_name, self.index)

    def get_output_result(self, i=0):
        """
        Returns the output name at position *i*.
        """
        if i != 0:
            raise IndexError(  # pragma: no cover
                "Can only return the first item.")
        return self.onx_op.get_output_result(self.index)

    def find_named_inputs(self):
        """
        Returns all inputs to the graph.
        """
        return self.onx_op.find_named_inputs()

    def f(self, *inputs, verbose=0, fLOG=None,  # pylint: disable=W0221
          clear_cache=False, runtime=None):
        """
        Computes the predictions for this node.
        Similar to an eager evaluation.

        :param inputs: inputs as dictionary or a list of inputs
            (see below)
        :param verbose: display information while predicting
        :param fLOG: logging function if *verbose > 0*
        :param clear_cache: onnx graph is created once unless
            this parameter is True
        :param runtime: runtime to use for the evaluation,
            see @see cl OnnxInference
        :return: outputs as a dictionary if the input were given as a
            dictionary or a single result or a tuple otherwise

        The evaluation is fully delegated to the parent node; only the
        output at position *index* is kept from its result.
        """
        res = self.onx_op.f(*inputs, verbose=verbose, fLOG=fLOG,
                            clear_cache=clear_cache, runtime=runtime)
        if not isinstance(res, dict):
            return res[self.index]
        names = self.onx_op.output_names
        if names is None:
            # Falls back on the declared outputs of the parent node.
            name = self.onx_op.expected_outputs[self.index][0]
        else:
            name = names[self.index]
        return {name: res[name]}
class OnnxOperatorTuple(OnnxOperatorBase):
    """
    Class used to return multiple @see cl OnnxVar
    at the same time.

    Exactly one of the two attributes is set:

    * `unique`: the wrapped object when the tuple holds one element,
    * `values`: a tuple of objects when it holds several.
    """

    def __init__(self, first, *args):
        OnnxOperatorBase.__init__(self)
        logger.debug("%s([%r], %d in)",
                     self.__class__.__name__, type(first), len(args))
        if isinstance(first, (list, tuple)):
            raise TypeError(  # pragma: no cover
                "Unexpected type for first %r." % type(first))
        logger.debug('OnnxOperatorTuple(%d in)', 1 + len(args))
        if len(args) > 0:
            self.values = (first,) + args
            self.unique = None
        else:
            self.values = None
            self.unique = first
        # Invariant: exactly one of (values, unique) is not None.
        if self.values is not None and self.unique is not None:
            raise RuntimeError(  # pragma: no cover
                "Unexpected configuration. One member (values or unique) must be "
                "null, unique=%r, values=%r" % (self.unique, self.values))
        if self.values is None and self.unique is None:
            raise RuntimeError(  # pragma: no cover
                "Unexpected configuration. One member (values or unique) must be "
                "not null.")

    def __repr__(self):
        "usual"
        if self.values is None:
            return "%s(%r)" % (self.__class__.__name__, type(self.unique))
        return "%s(%s)" % (self.__class__.__name__, ", ".join(
            "%r" % type(v) for v in self.values))

    @property
    def inputs(self):
        "Returns the only inputs in a list."
        if self.values is None:
            return [self.unique]
        raise NotImplementedError(  # pragma: no cover
            "OnnxOperatorTuple.inputs is missing.")

    def add_to(self, builder):
        """
        Adds to graph builder.
        Does nothing because the original node is already added.

        :param builder: instance of @see cl _GraphBuilder,
            it must have a method `add_node`
        """
        pass

    def __len__(self):
        "usual"
        if self.values is None:
            raise NotImplementedError(  # pragma: no cover
                "Not yet implemented in this case unique=%r, "
                "values=%r." % (self.unique, self.values))
        return len(self.values)

    def __iter__(self):
        "Iterates on the outputs."
        if self.values is None:
            raise NotImplementedError(  # pragma: no cover
                "Not yet implemented in this case.")
        for v in self.values:
            yield v

    def __getitem__(self, i):
        "usual"
        if self.values is None:
            return self.unique[i]
        return self.values[i]

    @property
    def outputs(self):
        "Returns 'output_names' of attribute 'unique'."
        if self.values is None:
            if hasattr(self.unique, 'to_onnx'):
                return self.unique.outputs
        raise NotImplementedError(  # pragma: no cover
            "Not implemented yet unique=%r values=%r." % (
                self.unique, self.values))

    @property
    def output_names(self):
        "Returns 'output_names' of attribute 'unique'."
        if self.values is None:
            if hasattr(self.unique, 'to_onnx'):
                return self.unique.output_names
        raise NotImplementedError(  # pragma: no cover
            "Not implemented yet unique=%r values=%r." % (
                self.unique, self.values))

    @output_names.setter
    def output_names(self, value):
        """
        Updates 'output_names' of attribute 'unique'
        or every output name of attribute 'values'.
        """
        logger.debug("OnnxOperatorTuple:output_names:set(%r)", value)
        OnnxIdentity = loadop('Identity')
        if self.values is None:
            if (hasattr(self.unique, 'to_onnx') or
                    hasattr(self.unique, 'add_to')):
                if len(value) > 1:
                    # Splits the unique output into one Identity node
                    # per requested name.
                    self.values = tuple(
                        OnnxIdentity(
                            self.unique[i], output_names=value[i:i + 1],
                            op_version=self.unique.op_version)
                        for i in range(0, len(value)))
                    self.unique = None
                    return
                self.unique.output_names = [Variable(v) for v in value]
                return
            raise NotImplementedError(  # pragma: no cover
                "Not implemented yet, value=%r, unique=%r values=%r." % (
                    value, self.unique, self.values))
        if self.values is not None and len(self.values) == len(value):
            for name, v in zip(value, self.values):
                v.output_names = [Variable(name)]
            return
        raise NotImplementedError(  # pragma: no cover
            "Not implemented yet, value=%r, unique=%r values=%r." % (
                value, self.unique, self.values))

    def to_onnx(self, inputs=None, outputs=None,
                other_outputs=None, target_opset=None,
                optim=True, verbose=0, run_shape=True):
        """
        Converts this operator into an ONNX graph.
        It follows the same signature as :meth:`OnnxOperator.to_onnx
        <mlprodict.npy.xop.OnnxOperator.to_onnx>` and calls this
        method of the unique input object or the first one
        if there are several. In that case, other inputs in
        attribute `values` are moved into container
        `other_outputs`.
        """
        if self.values is None:
            return self.unique.to_onnx(
                inputs=inputs, outputs=outputs, other_outputs=other_outputs,
                target_opset=target_opset, optim=optim, verbose=verbose,
                run_shape=run_shape)
        # BUG FIX: `self.values` is a tuple, so `self.values[1:]` is a
        # tuple as well and has no `extend` method; the previous code
        # raised AttributeError whenever `other_outputs` was provided.
        # Convert to a list before extending.
        new_other_outputs = list(self.values[1:])
        if other_outputs is not None:
            new_other_outputs.extend(other_outputs)
        return self.values[0].to_onnx(
            inputs=inputs, outputs=outputs, other_outputs=new_other_outputs,
            target_opset=target_opset, optim=optim, verbose=verbose,
            run_shape=run_shape)
783class OnnxOperator(OnnxOperatorBase):
784 """
785 Ancestor to every *ONNX* operator exposed in
786 :mod:`mlprodict.npy.xops` and :mod:`mlprodict.npy.xops_ml`.
788 :param inputs: list of inputs expected by the operator
789 :param op_version: to select a specific version of the operator
790 :param output_names: user defined names for the outputs
791 :param domain: to overwrite the default domain
792 :param global_context: operator *If* executes one subgraph
793 whose nodes may use one existing output in the current
794 context. If not used in the main graph, these operators
795 are not linked to the output and cannot be retrieved.
796 *global_context* is a dictionary mapped the subgraph input
797 names to these operators.
798 :param kwargs: additional parameters of the operator
800 .. versionadded:: 0.9
801 """
802 @classmethod
803 def __class_getitem__(cls, opset):
804 """
805 Enables expression `cls[opset]`. It returns the appropriate class
806 `cls_opset`. Parameter *op_version* should be specified.
807 """
808 if not isinstance(opset, int):
809 raise ValueError(
810 "opset must an integer not %r." % type(opset))
811 best = None
812 for _, v in cls.past_version.items():
813 if v.since_version == opset:
814 return lambda *args, **kwargs: v(
815 *args, op_version=opset, **kwargs)
816 if v.since_version <= opset and (
817 best is None or best.since_version < v.since_version):
818 best = v
819 if best is None:
820 raise ValueError(
821 "Unable to find a version of operator %r and opset %r." % (
822 cls.__name__, opset))
823 return lambda *args, **kwargs: best(
824 *args, op_version=opset, **kwargs)
    def __init__(self, *inputs, op_version=None, output_names=None,
                 domain=None, global_context=None, **kwargs):
        # Normalizes output names, resolves the opset version, validates
        # the inputs and pre-computes expected inputs/outputs metadata.

        OnnxOperatorBase.__init__(self)
        logger.debug("%s(%d in, op_version=%r, output_names=%r)",
                     self.__class__.__name__, len(inputs), op_version,
                     output_names)
        if (output_names is None and
                self.__class__.__name__.startswith("OnnxScan")):
            raise NotImplementedError(
                "The class cannot infer the number of variables "
                "for node '{}' yet. output_names must be specified"
                ".".format(self.__class__.__name__))
        # Normalizes output_names into a list of Variable instances.
        if isinstance(output_names, (str, Variable)):
            output_names = [output_names]
            if isinstance(output_names[0], str):
                output_names[0] = Variable(output_names[0])
        elif isinstance(output_names, list):
            if len(output_names) == 0:
                raise ValueError(  # pragma: no cover
                    "output_names cannot be empty (operator %r)."
                    "" % self.__class__.__name__)
            output_names = output_names.copy()
            for i in range(len(output_names)):  # pylint: disable=C0200
                if isinstance(output_names[i], str):
                    output_names[i] = Variable(output_names[i])
        elif output_names is not None:
            raise TypeError(  # pragma: no cover
                "output_names must be a string or a list not %r."
                "" % type(output_names))

        # Resolves the opset version; None means the latest available
        # for the default domain.
        if op_version is None:
            if domain == '':
                self.op_version = max_supported_opset()
            else:
                self.op_version = None
        else:
            self.op_version = op_version
        self.since_version = self.__class__.since_version

        if (self.op_version is not None and
                self.op_version < self.since_version):
            # The requested opset predates this class: use the schema
            # of an older version of the operator.
            schema = self.find_schema(self.op_version)
            self.since_version = schema.since_version
            self.expected_inputs = schema.expected_inputs.copy()
            self.expected_outputs = schema.expected_outputs.copy()
            self.input_range = schema.input_range
            self.output_range = schema.output_range
        else:
            self.expected_inputs = (
                None if self.__class__.expected_inputs is None
                else self.__class__.expected_inputs.copy())
            self.expected_outputs = (
                None if self.__class__.expected_outputs is None
                else self.__class__.expected_outputs.copy())
            self.input_range = self.__class__.input_range
            self.output_range = self.__class__.output_range
            if self.__class__.__name__ not in {
                    'OnnxScan', 'OnnxLoop', 'OnnxIf'}:
                # The minimum opset depends on embedded graph
                # by default, it takes the given op_version but the
                # optimal value could be lower.
                self.op_version = self.since_version
            if self.op_version is None:
                self.op_version = self.since_version

        if (self.op_version is not None and
                self.op_version < self.since_version):
            raise RuntimeError(  # pragma: no cover
                "Operator '{}': requested version {} < "
                "{} schema version.".format(
                    self.__class__.__name__,
                    self.op_version, self.since_version))

        self.state = None
        self.domain = domain
        self.kwargs = kwargs
        self.max_item_ = None

        # check inputs
        # Every input is normalized into a Variable, another operator
        # or a raw tensor-like value.
        self.inputs = []
        if len(inputs) > 0:
            for inp in inputs:
                if isinstance(inp, str):
                    self.inputs.append(Variable(inp))
                elif isinstance(inp, tuple):
                    # A tuple is interpreted as (name, typed value).
                    if len(inp) != 2:
                        raise RuntimeError(  # pragma: no cover
                            "Unexpected tuple %r." % (inp, ))
                    self.inputs.append(
                        Variable(inp[0], dtype=guess_numpy_type(inp[1]),
                                 shape=inp[1].shape))
                elif isinstance(inp, (OnnxOperatorBase, Variable)):
                    self.inputs.append(inp)
                elif isinstance(inp, (numpy.ndarray, coo_matrix, TensorProto)):
                    self.inputs.append(inp)
                elif isinstance(inp, ValueInfoProto):
                    self.inputs.append(inp.type.tensor_type)
                else:
                    raise TypeError(  # pragma: no cover
                        "Unable to interpret the input name for type {} in "
                        "operator '{}' (value={}).".format(
                            type(inp), self.__class__.__name__, inp))

        if (self.inputs is not None and
                (len(self.inputs) < self.input_range[0] or
                 len(self.inputs) > self.input_range[1])):
            raise RuntimeError(  # pragma: no cover
                "Operator '{}' expects a number of inputs in [{}, {}] not {} "
                "(expected opset={}, class opset={})".format(
                    getattr(self, 'operator_name', '?'), *self.input_range,
                    len(self.inputs), op_version, self.op_version))
        # global context
        if global_context is None:
            self.global_context = None
        else:
            if not isinstance(global_context, dict):
                raise TypeError(  # pragma: no cover
                    "global_context must be a dictionary not %r."
                    "" % type(global_context))
            for k, v in global_context.items():
                if not isinstance(v, OnnxOperatorBase):
                    raise TypeError(  # pragma: no cover
                        "Value %r in must be an OnnxOperatorBase not %r."
                        "" % (k, type(v)))
            self.global_context = global_context

        # check output
        self.output_names_ = output_names
        self.output_variables = None

        if self.output_names is not None:
            if len(self.output_names) == 0:
                raise ValueError(  # pragma: no cover
                    "output_names can be None but cannot be empty for "
                    "operator %r." % self)
            if self.output_variables is None:
                self.output_variables = [None for o in self.output_names]
            for i in range(len(self.output_names)):  # pylint: disable=C0200
                name = self.output_names[i]
                if isinstance(name, Variable):
                    self.output_variables[i] = name
                else:
                    raise TypeError(  # pragma: no cover
                        "output_names must be a list of strings "
                        "and element %r is %r (%r)" % (
                            i, type(name), name))
            if all(map(lambda x: x is None, self.output_variables)):
                self.output_variables = None

        # Extends expected_outputs when more names were requested than
        # the schema declares.
        if (self.output_names is not None and (
                self.expected_outputs is None or
                len(self.output_names) > len(self.expected_outputs))):
            if self.expected_outputs is None:
                self.expected_outputs = []
            for i in range(len(self.expected_outputs),
                           len(self.output_names)):
                self.expected_outputs.append((self.output_names[i], None))

        # Extends expected_inputs when more inputs were given than the
        # schema declares.
        if (self.expected_inputs is None or
                len(self.inputs) > len(self.expected_inputs)):
            if self.expected_inputs is None:
                self.expected_inputs = []
            for i in range(len(self.expected_inputs),
                           len(self.inputs)):
                inp = self.inputs[i]
                if isinstance(inp, str):
                    inp = (inp, None)
                elif hasattr(inp, 'add_to'):
                    # OnnxOperator
                    existing = set(_[0] for _ in self.expected_inputs)
                    # NOTE(review): the counter starts at i = 10 so the
                    # generated names begin at 'input20' (10 + i); the
                    # rebinding of the loop variable `i` also looks
                    # deliberate but fragile — confirm this offset is
                    # intentional.
                    i = 10
                    name = "input%d" % (10 + i)
                    while name in existing:
                        i += 1
                        name = "input%d" % (10 + i)
                    inp = (name, None)
                self.expected_inputs.append(inp)

        self._post_process_attributes()
        self._check()
    @property
    def output_names(self):
        "Returns `self.output_names_`, the output names requested by the user (or None)."
        return self.output_names_
    @output_names.setter
    def output_names(self, value):
        # Traced because output names drive how the final graph is built.
        logger.debug("OnnxOperator:output_names:set(%r)", value)
        self.output_names_ = value
1018 def _check(self):
1019 input_types = (Variable, OnnxOperatorBase, numpy.ndarray,
1020 TensorProto)
1021 for o in self.inputs:
1022 if not isinstance(o, input_types):
1023 raise TypeError( # pragma: no cover
1024 "Wrong type for inputs %r." % (
1025 self.inputs, ))
1026 if self.output_names is not None:
1027 for o in self.output_names:
1028 if not isinstance(o, Variable):
1029 raise TypeError( # pragma: no cover
1030 "Wrong type for output_names %r." % (
1031 self.output_names, ))
1033 def _post_process_attributes(self):
1034 """
1035 Walks through attributes and replaces them by ONNX values.
1036 """
1037 # Looks into attributes if there is any tuple
1038 # (GraphProto, OnnxOperator). In that case, the function
1039 # replaces the tuple by the graph proto and keeps
1040 # in attributes graph_algebra the OnnxOperator
1041 # which is the source of it.
1042 updates = {}
1043 graph_algebra = {}
1044 for k, v in self.kwargs.items():
1045 if isinstance(v, tuple) and isinstance(v[0], GraphProto):
1046 updates[k] = v[0]
1047 graph_algebra[k] = v[1]
1049 if len(graph_algebra) > 0:
1050 self.kwargs.update(updates)
1051 self.graph_algebra = graph_algebra
1053 if self.__class__.__name__ == "OnnxConstantOfShape":
1054 if "value" in self.kwargs:
1055 value = self.kwargs['value']
1056 if isinstance(value, TensorProto):
1057 return
1058 if isinstance(value, numpy.ndarray):
1059 if value.shape == (1, ):
1060 val = value[0]
1061 elif len(value.shape) == 0:
1062 val = value
1063 else:
1064 raise RuntimeError( # pragma: no cover
1065 "Unexpected shape %r for value, it must be "
1066 "an array of one element." % value.shape)
1067 self.kwargs['value'] = from_array(
1068 numpy.array([val], dtype=value.dtype))
1069 return
1070 raise TypeError( # pragma: no cover
1071 "Unexpected type %r for value. It should be an array "
1072 "of one element." % type(value))
1073 return
1075 if self.__class__.__name__ == "OnnxCast":
1076 if "to" in self.kwargs:
1077 value = self.kwargs['to']
1078 if not isinstance(value, int):
1079 try:
1080 to = numpy_type_prototype(value)
1081 except ValueError as e: # pragma: no cover
1082 raise ValueError(
1083 "Unable to convert argument to in operator cast, "
1084 "type is %r, value is %r." % (type(value), value)) from e
1085 self.kwargs['to'] = to
1086 return
1088 def update_max_item(self, index):
1089 """
1090 Some operators return a undefined number of outputs.
1091 The method is called when require one of them (with `__getitem__`)
1092 and keeps the greater requested index assuming the node does
1093 not output any result beyond that index.
1095 :param index: requested index
1096 """
1097 if self.max_item_ is None:
1098 self.max_item_ = index
1099 else:
1100 self.max_item_ = max(self.max_item_, index)
1101 if self.expected_outputs is None:
1102 self.expected_outputs = []
1103 while len(self.expected_outputs) <= self.max_item_:
1104 self.expected_outputs.append(
1105 (("NEWOUTPUT", len(self.expected_outputs)), None))
1107 def find_schema(self, op_version):
1108 """
1109 Checks if there is an existing schema for a specific version.
1111 :param op_version: requested version
1112 :return: schema
1113 """
1114 if not hasattr(self.__class__, 'past_version'):
1115 raise RuntimeError( # pragma: no cover
1116 "Missing attribute 'past_version', there is "
1117 "no other available schema.")
1118 found = None
1119 for v in self.past_version.values():
1120 if v.since_version > op_version:
1121 continue
1122 if found is None or v.since_version > found.since_version:
1123 found = v
1124 if found is None:
1125 raise RuntimeError( # pragma: no cover
1126 "Operator '{}': requested version {} < "
1127 "{} schema version (past_version {}).".format(
1128 self.__class__.__name__,
1129 op_version, self.since_version,
1130 [v.since_version for v in self.past_version.values()]))
1131 return found
1133 def __repr__(self):
1134 """
1135 usual
1136 """
1137 return "{}({} in) -> {}".format(
1138 self.__class__.__name__,
1139 len(self.inputs) if self.inputs is not None else 0,
1140 [str(o) for o in self.output_names]
1141 if self.output_names is not None else "?")
    def get_output_result(self, i=0):
        """
        Returns the output name at position *i*.

        :param i: output index, defaults to 0
        :return: @see cl NodeResultName
        """
        return NodeResultName(self, i)
    def __getitem__(self, index):
        """
        Returns an accessor to one of the output
        of this node.

        :param index: output index
        :return: @see cl OnnxOperatorItem
        """
        # remember the highest requested index, the node is assumed
        # to produce at least index+1 outputs
        self.update_max_item(index)
        return OnnxOperatorItem(self, index, self.op_version)
1157 def __iter__(self):
1158 """
1159 Allows expressions such as ``a, b = OnnxTopK(...)``.
1160 """
1161 n = None
1162 if self.output_names is not None:
1163 n = len(self.output_names)
1164 else:
1165 rg = self.output_range
1166 if rg[0] == rg[1] and rg[0] > 0:
1167 n = rg[0]
1168 if n is None and self.max_item_ is not None:
1169 n = self.max_item_ + 1
1170 if n is None:
1171 raise RuntimeError( # pragma: no cover
1172 "Unable to guess the number of outputs of node type %r. "
1173 "Uses operator [] to select a specific output." %
1174 self.__class__.__name__)
1175 if self.max_item_ is not None:
1176 n = max(n, self.max_item_ + 1)
1177 for i in range(n):
1178 yield self[i]
1180 def add_to(self, builder):
1181 """
1182 Adds to graph builder.
1184 :param builder: instance of @see cl _GraphBuilder,
1185 it must have a method `add_node`
1186 """
1187 logger.debug("%s.add_to(builder)", self.__class__.__name__)
1188 inputs = builder.get_input_names(self, self.inputs)
1189 if self.output_names is not None:
1190 n_outputs = len(self.output_names)
1191 elif self.expected_outputs is not None:
1192 n_outputs = len(self.expected_outputs)
1193 else:
1194 n_outputs = self.output_range[0]
1195 outputs = [builder.get_unique_output_name(NodeResultName(self, i))
1196 for i in range(n_outputs)]
1197 builder.add_node(
1198 self.operator_name,
1199 builder.get_unique_name(
1200 '_' + self.operator_name.lower(), reserved=False),
1201 inputs, outputs, domain=self.domain, opset=self.op_version,
1202 **self.kwargs)
1204 @staticmethod
1205 def _node_to_graph_preprocess_list(inputs):
1206 new_inputs = OrderedDict()
1207 for el in inputs:
1208 if isinstance(el, str):
1209 new_inputs[el] = Variable(el)
1210 elif isinstance(el, Variable):
1211 new_inputs[el.name] = el
1212 elif isinstance(el, tuple) and len(el) == 2:
1213 # sklearn-onnx
1214 new_inputs[el[0]] = Variable(
1215 el[0], guess_numpy_type(el[1]), el[1].shape)
1216 else:
1217 raise TypeError( # pragma: no cover
1218 "Unable to handle input type %r (%r)." % (type(el), el))
1219 return new_inputs
    @staticmethod
    def _node_to_graph_process_input(inputs, set_inputs, node, inp,
                                     new_inputs, new_stack, inputs_dtype,
                                     as_function=False):
        """
        Processes one input *inp* of *node* while walking the graph:
        operators are pushed on *new_stack* for the next iteration,
        variables are registered in *new_inputs* (typed from *inputs*
        or *inputs_dtype*), numpy constants are ignored.

        :param inputs: dictionary `{name: Variable}`, a Variable,
            a dtype or None
        :param set_inputs: set of variable names already registered
        :param node: node owning *inp*
        :param inp: input to process
        :param new_inputs: list of @see cl InputDetectedVariable to extend
        :param new_stack: nodes to explore at the next iteration
        :param inputs_dtype: default dtype when *inputs* gives none
        :param as_function: True when building a FunctionProto,
            inputs may then remain untyped
        """
        if not as_function and inputs is None and inputs_dtype is None:
            raise RuntimeError(  # pragma: no cover
                "Both inputs and inputs_dtype cannot be None at the same time "
                "for inp=%r." % (inp, ))
        if isinstance(inp, OnnxOperator):
            new_stack.append(inp)
        elif isinstance(inp, OnnxOperatorItem):
            # explore both the item and the operator it wraps
            new_stack.append(inp)
            new_stack.append(inp.onx_op)
        elif isinstance(inp, OnnxOperatorTuple):
            # new_stack.append(inp)
            # new_stack.append(inp.onx_op)
            raise NotImplementedError(  # pragma: no cover
                "Unable to guess inputs when one input is OnnxOperatorTuple.")
        elif isinstance(inp, Variable):
            if inp.name in set_inputs:
                # already registered by another node
                return
            set_inputs.add(inp.name)
            if inputs is None and inputs_dtype is None:
                new_inputs.append(InputDetectedVariable(node, inp))
            elif isinstance(inputs, dict):
                if inp.name in inputs:
                    # merge the user-declared type into the detected variable
                    new_inputs.append(
                        InputDetectedVariable(
                            node, inp.copy_merge(inputs[inp.name])))
                else:
                    raise ValueError(  # pragma: no cover
                        "Unable to find input %r in %r." % (
                            inp, inputs))
            elif inputs_dtype is not None:
                new_inputs.append(
                    InputDetectedVariable(node, inp.copy_add(inputs_dtype)))
            elif isinstance(inputs, Variable):
                if inp.name == inputs.name:
                    new_inputs.append(
                        InputDetectedVariable(node, inp.copy_merge(inputs)))
                else:
                    new_inputs.append(
                        InputDetectedVariable(node, inp))
            else:
                raise RuntimeError(  # pragma: no cover
                    "Unable to handle inputs=%r." % inputs)
        elif isinstance(inp, numpy.ndarray):
            # constant, turned into an initializer later on
            pass
        else:
            raise TypeError(  # pragma: no cover
                "Unexpected input type %r in node type %r." % (
                    type(inp), type(node)))
1274 @staticmethod
1275 def _node_to_graph_get_type(node, name=None, outputs=None,
1276 outputs_dtype=None):
1277 if outputs is None:
1278 return outputs_dtype
1279 if isinstance(outputs, Variable):
1280 if name is None:
1281 return outputs.dtype or outputs_dtype
1282 if isinstance(name, Variable):
1283 return outputs.dtype or name.dtype or outputs_dtype
1284 else:
1285 raise RuntimeError( # pragma: no cover
1286 "Unable to handle outputs=%r." % outputs)
1287 if isinstance(outputs, dict):
1288 if name is None:
1289 raise RuntimeError( # pragma: no cover
1290 "Unable to get type among %r, name=None." % (
1291 outputs, ))
1292 if isinstance(name, Variable):
1293 n = name.name
1294 else:
1295 n = name
1296 if n not in outputs:
1297 return None
1298 return outputs[n]
1299 if isinstance(outputs, list):
1300 raise NotImplementedError( # pragma: no cover
1301 "Unexpected type for name=%r, outputs=%r." % (
1302 name, outputs))
1303 if is_numpy_dtype(outputs):
1304 return outputs
1305 raise RuntimeError( # pragma: no cover
1306 "Unable to handle outputs=%r." % outputs)
1308 @staticmethod
1309 def _node_to_graph_reorder_by_name(new_inputs, inputs):
1310 memo = OrderedDict((n.name, n) for n in new_inputs)
1311 done = set()
1312 result = []
1313 for inp in inputs:
1314 if inp.name in memo:
1315 result.append(memo[inp.name])
1316 done.add(inp.name)
1317 for k, v in memo.items():
1318 if k in done:
1319 continue
1320 result.append(v)
1321 return result
1323 def _node_to_graph(self, other_outputs=None, inputs=None, outputs=None,
1324 as_function=False):
1325 """
1326 Builds a graph as a list of nodes to walk through in that order.
1327 """
1329 node_outputs = [self]
1330 if other_outputs is not None:
1331 node_outputs += other_outputs
1333 logger.debug("%s._node_to_graph:inputs=%r",
1334 self.__class__.__name__, inputs)
1335 logger.debug("%s._node_to_graph:outputs=%r",
1336 self.__class__.__name__, outputs)
1338 # preprocess inputs, outputs
1339 _keep_inputs = None
1340 inputs_dtype = None
1341 if isinstance(inputs, list):
1342 _keep_inputs = inputs
1343 inputs_dict = self._node_to_graph_preprocess_list(inputs)
1344 elif isinstance(inputs, dict):
1345 inputs_dict = inputs
1346 elif isinstance(inputs, Variable):
1347 inputs = [inputs]
1348 inputs_dict = self._node_to_graph_preprocess_list(inputs)
1349 elif is_numpy_dtype(inputs):
1350 inputs_dtype = inputs
1351 inputs_dict = None
1352 else:
1353 raise TypeError( # pragma: no cover
1354 "Unexpected type %r for inputs." % type(inputs))
1356 _keep_outputs = None
1357 outputs_dtype = None
1358 if isinstance(outputs, list):
1359 _keep_outputs = outputs
1360 outputs_dict = self._node_to_graph_preprocess_list(outputs)
1361 elif isinstance(outputs, dict):
1362 outputs_dict = outputs
1363 elif isinstance(outputs, Variable):
1364 outputs = [outputs]
1365 outputs_dict = self._node_to_graph_preprocess_list(outputs)
1366 elif is_numpy_dtype(outputs):
1367 outputs_dtype = outputs
1368 outputs_dict = None
1369 else:
1370 raise TypeError( # pragma: no cover
1371 "Unexpected type %r for outputs." % type(outputs))
1373 logger.debug("%s._node_to_graph:inputs=%r",
1374 self.__class__.__name__, inputs)
1375 logger.debug("%s._node_to_graph:outputs=%r",
1376 self.__class__.__name__, outputs)
1377 logger.debug("%s._node_to_graph:inputs_dict=%r",
1378 self.__class__.__name__, inputs_dict)
1379 logger.debug("%s._node_to_graph:outputs_dict=%r",
1380 self.__class__.__name__, outputs_dict)
1381 logger.debug("%s._node_to_graph:inputs_dtype=%r",
1382 self.__class__.__name__, inputs_dtype)
1383 logger.debug("%s._node_to_graph:outputs_dtype=%r",
1384 self.__class__.__name__, outputs_dtype)
1386 # walk through graph
1387 stack = list(node_outputs)
1388 new_inputs = []
1389 set_inputs = set()
1390 memo = []
1391 while len(stack) > 0:
1392 memo.extend(stack)
1393 new_stack = []
1394 for obj in stack:
1395 if isinstance(obj, OnnxOperatorItem):
1396 # nothing to do, OnnxOperatorItem is created
1397 # by OnnxOperator.__getitem__.
1398 pass
1399 elif isinstance(obj, (OnnxOperator, OnnxOperatorTuple)):
1400 for inp in obj.inputs:
1401 self._node_to_graph_process_input(
1402 inputs_dict, set_inputs, obj, inp, new_inputs,
1403 new_stack, inputs_dtype, as_function=as_function)
1404 else:
1405 raise TypeError( # pragma: no cover
1406 "Unexpected type %r." % type(obj))
1407 stack = new_stack
1409 # reorder new_inputs to follow inputs initial order
1410 if _keep_inputs is not None:
1411 new_inputs = self._node_to_graph_reorder_by_name(
1412 new_inputs, inputs)
1414 logger.debug("%s._node_to_graph:new_inputs=%r",
1415 self.__class__.__name__, new_inputs)
1417 # eliminate duplicates
1418 done = set()
1419 nodes = []
1420 for node in reversed(memo):
1421 if id(node) in done:
1422 continue
1423 done.add(id(node))
1424 nodes.append(node)
1426 # outputs
1427 set_names = set()
1428 new_outputs = []
1429 run_shape = False
1430 for node in node_outputs:
1431 if node.output_names is None:
1432 n = self.output_range[0]
1433 for i in range(n):
1434 to = self._node_to_graph_get_type(
1435 node, outputs=outputs_dict,
1436 outputs_dtype=outputs_dtype)
1437 if to is None:
1438 run_shape = True
1439 res = '???_%d' % i
1440 var = Variable(res, added_dtype=to)
1441 if var.name in set_names:
1442 raise RuntimeError( # pragma: no cover
1443 "Duplicated output name var=%r." % var)
1444 set_names.add(var.name)
1445 new_outputs.append(OutputDetectedVariable(node, var, i))
1446 else:
1447 for i, o in enumerate(node.output_names):
1448 if isinstance(o, str):
1449 raise TypeError( # pragma: no cover
1450 "Output %d - %r (%r) not allowed in node %r." % (
1451 i, o, node.output_names, node))
1452 to = self._node_to_graph_get_type(
1453 node, o, outputs=outputs_dict,
1454 outputs_dtype=outputs_dtype)
1455 if to is None:
1456 run_shape = True
1457 res = (o, to)
1458 var = o.copy_merge(to)
1459 if var.name in set_names:
1460 raise RuntimeError( # pragma: no cover
1461 "Duplicated output name o=%r var=%r." % (o, var))
1462 set_names.add(var.name)
1463 new_outputs.append(OutputDetectedVariable(node, var, i))
1464 if len(new_outputs) == 0:
1465 raise RuntimeError( # pragma: no cover
1466 "No detected outputs inputs=%r outputs=%r." % (
1467 inputs_dict, outputs_dict))
1469 # reorder new_outputs to follow outputs initial order
1470 if _keep_outputs is not None:
1471 new_outputs = self._node_to_graph_reorder_by_name(
1472 new_outputs, outputs)
1474 logger.debug("%s._node_to_graph:new_outputs=%r",
1475 self.__class__.__name__, new_outputs)
1477 return nodes, new_inputs, new_outputs, run_shape
1479 def to_onnx(self, inputs=None, outputs=None,
1480 other_outputs=None, target_opset=None,
1481 optim=True, verbose=0, run_shape=True,
1482 function_name=None, function_domain=None,
1483 fLOG=print):
1484 """
1485 Converts this operator into an ONNX graph.
1487 :param inputs: information about type, it should not be None
1488 :param outputs: information about types, if None, the function
1489 will use shape inference to guess the final output type
1490 and shape
1491 :param other_outputs: additional nodes to consider
1492 as graph outputs but not outputs of this particular
1493 node
1494 :param target_opset: dictionary with target opset per domain,
1495 None for the default one
1496 :param optim: optimize the model with function
1497 @see fn onnx_optimisations
1498 :param run_shape: in case output shapes are not specify,
1499 the function runs function :epkg:`infer_shapes`
1500 to guess them, False would disable that
1501 default behaviour
1502 :param verbose: prints information
1503 :param function_name: if not None, returns a :epkg:`FunctionProto`
1504 :param function_domain: in case of a function, declares the function
1505 as part of this domain
1506 :param fLOG: logging function
1507 :return ONNX stucture
1508 """
1509 # opsets
1510 logger.debug(
1511 "%s.to_onnx(%r, %r, other_outputs=%r, target_opset=%r, as_function=%r)",
1512 self.__class__.__name__, inputs, outputs,
1513 other_outputs, target_opset, function_name)
1514 if isinstance(target_opset, dict):
1515 dom = self.domain or ''
1516 target_opset = target_opset.get(dom, None)
1517 elif isinstance(target_opset, int):
1518 if self.domain not in ('', None):
1519 # The target_opset is for the domain '' we ignore it.
1520 target_opset = None
1521 elif target_opset is not None:
1522 raise TypeError( # pragma: no cover
1523 "target_opset must be a dictionary {domain: "
1524 "target_opset} not %r for operator %r." % (
1525 target_opset, self.__class__.__name__))
1527 if self.domain in ('', None) and target_opset == 1:
1528 raise RuntimeError( # pragma: no cover
1529 "target_opset cannot be 1.")
1530 if (self.op_version is not None and target_opset is not None and
1531 self.op_version > target_opset):
1532 raise RuntimeError( # pragma: no cover
1533 "target_opset={} is lower than the version={} requested "
1534 "for this node '{}'.".format(
1535 target_opset, self.op_version, self.__class__.__name__))
1537 # get the graph
1538 nodes, graph_inputs, graph_outputs, run_shape2 = self._node_to_graph(
1539 other_outputs, inputs, outputs, as_function=function_name is not None)
1540 logger.debug("%s.to_onnx:graph_inputs=%r",
1541 self.__class__.__name__, graph_inputs)
1542 logger.debug("%s.to_onnx:graph_outputs=%r",
1543 self.__class__.__name__, graph_outputs)
1544 if len(nodes) == 0:
1545 raise RuntimeError( # pragma: no cover
1546 "Node list is empty.")
1547 if verbose > 1:
1548 for i, n in enumerate(nodes): # pragma: no cover
1549 fLOG("nodes[%d]=%r" % (i, n))
1550 for i, n in enumerate(graph_inputs): # pragma: no cover
1551 fLOG("graph_inputs[%d]=%r" % (i, n))
1553 # creates a _GraphBuilder
1554 builder = _GraphBuilder()
1556 # reserve input names starting by the first one
1557 for node in reversed(nodes):
1558 for var in node.inputs:
1559 if isinstance(var, Variable):
1560 logger.debug("%s.to_onnx:_add_name(%r)",
1561 self.__class__.__name__, var.name)
1562 builder._add_name(var.name)
1564 # reserve output names starting by the last ones
1565 for node in reversed(nodes):
1566 builder.reserve_names(node, node.output_names)
1568 # adds every node to the builder
1569 for i, node in enumerate(nodes):
1570 logger.debug("%s.to_onnx:node:%d/%d:%r",
1571 self.__class__.__name__, i, len(nodes), node)
1573 for node in nodes:
1574 node.add_to(builder)
1576 return builder.to_onnx(
1577 inputs=graph_inputs, outputs=graph_outputs,
1578 target_opset=target_opset, verbose=verbose,
1579 optim=optim, run_shape=run_shape and run_shape2,
1580 function_name=function_name, function_domain=function_domain)
1582 def predecessors(self):
1583 """
1584 Returns the list of predecessors.
1586 :return: list of @see cl OnnxOperator
1587 """
1588 stack = [self]
1589 last = 0
1590 while True:
1591 end = len(stack)
1592 if end == last:
1593 break
1594 for i in range(last, end):
1595 node = stack[i]
1596 for inp in node.inputs:
1597 if isinstance(inp, OnnxOperatorBase):
1598 stack.append(inp)
1599 last = end
1600 return stack
1602 def __call__(self, *args, function_name=None, function_domain=None,
1603 **kwargs):
1604 """
1605 Creates an instance of class @see cl OnnxOperatorFunction.
1606 Equivalent to `OnnxOperatorFunction(proto, *args, **kwargs)`.
1608 :param args: see @see cl OnnxOperatorFunction
1609 :param function_name: name to be given to the function
1610 :param function_domain: function domain, if None,
1611 it is given a default value
1612 :param kwargs: see @see cl OnnxOperatorFunction
1613 :return: instance of type @see cl OnnxOperatorFunction
1614 """
1615 if function_name is None:
1616 def clean(name):
1617 if name.startswith("Onnx"):
1618 name = name[4:]
1619 return name
1621 pred = self.predecessors()
1622 cls = [clean(p.__class__.__name__) for p in pred]
1623 function_name = "".join(cls)
1624 onx = self.to_onnx(function_name=function_name,
1625 function_domain=function_domain)
1626 return OnnxOperatorFunction(onx, *args, **kwargs)
1628 def find_named_inputs(self):
1629 """
1630 Retrieves all named inputs in this graph.
1631 """
1632 unique = set()
1633 found = []
1634 for inp in self.inputs:
1635 if isinstance(inp, str):
1636 if inp not in unique:
1637 found.append(inp)
1638 unique.add(inp)
1639 elif isinstance(inp, Variable):
1640 if inp.name not in unique:
1641 found.append(inp.name)
1642 unique.add(inp.name)
1643 elif isinstance(inp, OnnxOperatorBase):
1644 f = inp.find_named_inputs()
1645 for n in f:
1646 if n not in unique:
1647 found.append(n)
1648 unique.add(n)
1649 elif isinstance(inp, numpy.ndarray):
1650 pass
1651 else:
1652 raise RuntimeError( # pragma: no cover
1653 "Unexpected input type %r." % type(inp))
1654 return found
1656 def to_onnx_this(self, evaluated_inputs):
1657 """
1658 Returns a simple ONNX graph corresponding to this node.
1660 :param evaluated_inputs: inputs as a list
1661 :return: ONNX graph
1662 """
1663 inputs_names = ['I%d' % i for i in range(len(evaluated_inputs))]
1664 if self.output_names is None:
1665 if self.expected_outputs is None:
1666 raise NotImplementedError(
1667 "expected_outputs and output_names are not defined.")
1668 output_names = [o[0] for o in self.expected_outputs]
1669 else:
1670 output_names = [o.name for o in self.output_names]
1671 node = make_node(self.op_type, inputs_names, output_names,
1672 domain=self.domain, name="f", **self.kwargs)
1673 onx_inputs = [Variable(name, a.dtype).make_value_info()
1674 for name, a in zip(inputs_names, evaluated_inputs)]
1675 onx_outputs = [make_value_info(name, make_tensor_type_proto(0, []))
1676 for name in output_names]
1677 graph = make_graph([node], 'f', onx_inputs, onx_outputs)
1678 model = make_model(
1679 graph, opset_imports=[make_operatorsetid(
1680 self.domain or '', self.since_version)])
1681 return model
1683 def run(self, *inputs, verbose=0, fLOG=None, clear_cache=False, runtime=None):
1684 """
1685 Other name for
1686 `OnnxInference.f <mlprodict.onnxrt.onnx_inference.OnnxInference.f>`_.
1687 """
1688 return self.f(*inputs, verbose=verbose, fLOG=fLOG,
1689 clear_cache=clear_cache, runtime=runtime)
    def f(self, *inputs, verbose=0, fLOG=None,  # pylint: disable=W0221
          clear_cache=False, runtime=None):
        """
        Computes the predictions for this node.
        Similar to an eager evaluation.

        :param inputs: inputs as dictionary or a list of inputs
            (see below)
        :param verbose: display information while predicting,
            *fLOG* must be given as well when *verbose > 0*
        :param fLOG: logging function if *verbose > 0*
        :param clear_cache: onnx graph is created once unless
            this parameter is True
        :param runtime: runtime to use for the evaluation,
            see @see cl OnnxInference
        :return: outputs as a dictionary if the input were given as a
            dictionary or a single result or a tuple otherwise

        The inputs refer to the inputs of the graph.
        The method walks through all inputs and finds inputs defined as
        string. It replaces them by the value found in the dictionary.
        If the inputs are specified in a list, the function retrieves the
        list of inputs defined as a string and assigns them a value.
        Logging function can be used to get more insight about it.
        During the evaluation every node is independently converted
        into ONNX. The ONNX graph is cached in the class itself.
        """
        # input evaluation
        # NOTE(review): fLOG defaults to None but is called whenever
        # verbose > 0 — callers must supply fLOG together with verbose.
        if len(inputs) == 1 and isinstance(inputs[0], dict):
            # inputs given by name
            dict_inputs = inputs[0]
            as_dict = True
        elif not isinstance(inputs, (tuple, list)):
            raise TypeError(  # pragma: no cover
                "inputs must be a list not %r." % type(inputs))
        elif len(inputs) > 0 and isinstance(inputs[0], OnnxOperator):
            raise TypeError(  # pragma: no cover
                "Unexpected type for inputs[0]: %r." % type(inputs[0]))
        else:
            # positional inputs: map them onto the named inputs of the
            # graph (discovered once and cached on the instance)
            as_dict = False
            if verbose > 0:
                fLOG(  # pragma: no cover
                    "[OnnxOperator.f] retrieves named inputs")
            if hasattr(self, "feval_named_inputs_"):
                named_inputs = self.feval_named_inputs_  # pylint: disable=E0203
            else:
                named_inputs = self.find_named_inputs()
                self.feval_named_inputs_ = named_inputs
            if len(named_inputs) != len(inputs):
                raise RuntimeError(
                    "Mismatch between the number of found inputs (%d) and "
                    "the number of given inputs (%d) (found %r)."
                    "" % (
                        len(named_inputs), len(inputs), named_inputs))
            dict_inputs = {
                name: value for name, value in zip(named_inputs, inputs)}
            if verbose > 0:
                fLOG(  # pragma: no cover
                    "[OnnxOperator.f] found inputs: %r" % (named_inputs, ))

        # conversion
        # evaluate every input: strings and Variables are looked up in
        # dict_inputs, operators are evaluated recursively, arrays are
        # used as-is
        evaluated_inputs = []
        for i, inp in enumerate(self.inputs):
            if isinstance(inp, str):
                evaluated_inputs.append(dict_inputs[inp])
            elif isinstance(inp, Variable):
                evaluated_inputs.append(dict_inputs[inp.name])
            elif isinstance(inp, OnnxOperatorBase):
                if verbose > 0:
                    fLOG(  # pragma: no cover
                        "[OnnxOperator.f] evaluate input %d (op_type=%r)" % (
                            i, self.__class__.op_type))
                out = inp.f(dict_inputs, verbose=verbose, fLOG=fLOG)
                if isinstance(out, dict):
                    if len(out) == 1:
                        evaluated_inputs.append(out.popitem()[1])
                    else:
                        raise NotImplementedError(
                            "Not yet implemented in case when there are multiple "
                            "outputs (%r)." % list(out))
                elif isinstance(out, list):
                    evaluated_inputs.extend(out)
                else:
                    evaluated_inputs.append(out)
            elif isinstance(inp, numpy.ndarray):
                evaluated_inputs.append(inp)
            else:
                raise RuntimeError(  # pragma: no cover
                    "Unexpected type %r for input %d." % (type(inp), i))

        # conversion to ONNX
        # the ONNX graph for this node is cached per (dtype, shape)
        # signature of its evaluated inputs
        if not hasattr(self, 'feval_onnx_'):
            self.feval_onnx_ = {}
        key = tuple((m.dtype, m.shape) for m in evaluated_inputs)
        if key not in self.feval_onnx_ or clear_cache:
            if verbose > 0:
                fLOG("[OnnxOperator.f] creating node %r, inputs=%r" % (
                    self.op_type, key))
            # local import to avoid a circular dependency at module load
            from ..onnxrt import OnnxInference
            model = self.to_onnx_this(evaluated_inputs)
            oinf = OnnxInference(model, runtime=runtime)
            self.feval_onnx_[key] = oinf
        else:
            oinf = self.feval_onnx_[key]

        # execution
        if verbose > 0:
            fLOG("[OnnxOperator.f] execute node %r" % self.op_type)
        got = oinf.run({k: v for k, v in
                        zip(oinf.input_names, evaluated_inputs)})
        if as_dict:
            return got
        if len(got) == 1:
            return got.popitem()[1]
        return [got[n] for n in oinf.output_names]
1805 @staticmethod
1806 def _merge_op_version(n1, n2):
1807 if isinstance(n2, OnnxOperator):
1808 if n1.op_version is None:
1809 opv = n2.op_version
1810 elif n2.op_version is None:
1811 opv = n1.op_version
1812 elif n1.op_version == n2.op_version:
1813 opv = n1.op_version
1814 else:
1815 opv = max(n1.op_version, n2.op_version)
1816 elif isinstance(n2, OnnxOperatorItem):
1817 opv = OnnxOperator._merge_op_version(n1, n2.onx_op)
1818 elif isinstance(n2, OnnxOperatorTuple):
1819 raise NotImplementedError( # pragma: no cover
1820 "_merge_op_version is not implemented when n2 "
1821 "is OnnxOperatorTuple.")
1822 else:
1823 opv = n1.op_version
1824 return opv
1826 def __add__(self, ov):
1827 """
1828 Automatically adds operator `OnnxAdd` to the graph.
1830 :param ov: onnx node
1831 :return: `OnnxAdd(self, ov)`
1832 """
1833 OnnxAdd = loadop('Add')
1834 opv = self._merge_op_version(self, ov)
1835 return OnnxAdd(self, ov, op_version=opv)
1837 def __sub__(self, ov):
1838 """
1839 Automatically adds operator `OnnxSub` to the graph.
1841 :param ov: onnx node
1842 :return: `OnnxSub(self, ov)`
1843 """
1844 OnnxSub = loadop('Sub')
1845 opv = self._merge_op_version(self, ov)
1846 return OnnxSub(self, ov, op_version=opv)
1848 def __mul__(self, ov):
1849 """
1850 Automatically adds operator `OnnxMul` to the graph.
1852 :param ov: onnx node
1853 :return: `OnnxMul(self, ov)`
1854 """
1855 OnnxMul = loadop('Mul')
1856 opv = self._merge_op_version(self, ov)
1857 return OnnxMul(self, ov, op_version=opv)
1859 def __truediv__(self, ov):
1860 """
1861 Automatically adds operator `OnnxDiv` to the graph.
1863 :param ov: onnx node
1864 :return: `OnnxDiv(self, ov)`
1865 """
1866 OnnxDiv = loadop('Div')
1867 opv = self._merge_op_version(self, ov)
1868 return OnnxDiv(self, ov, op_version=opv)
1870 def __pow__(self, ov):
1871 """
1872 Automatically adds operator `OnnxPow` to the graph.
1874 :param ov: onnx node
1875 :return: `OnnPow(self, ov)`
1876 """
1877 OnnxPow = loadop('Pow')
1878 opv = self._merge_op_version(self, ov)
1879 return OnnxPow(self, ov, op_version=opv)
1881 def __mod__(self, ov):
1882 """
1883 Automatically adds operator `OnnxMod` to the graph.
1885 :param ov: onnx node
1886 :return: `OnnxMod(self, ov)`
1887 """
1888 OnnxMod = loadop('Mod')
1889 opv = self._merge_op_version(self, ov)
1890 return OnnxMod(self, ov, op_version=opv)
1892 def __matmul__(self, ov):
1893 """
1894 Automatically adds operator `OnnxMatMul` to the graph.
1896 :param ov: onnx node
1897 :return: `OnnMatMul(self, ov)`
1898 """
1899 OnnxMatMul = loadop('MatMul')
1900 opv = self._merge_op_version(self, ov)
1901 return OnnxMatMul(self, ov, op_version=opv)
1903 def __gt__(self, ov):
1904 """
1905 Automatically adds operator `OnnxGreater` to the graph.
1907 :param ov: onnx node
1908 :return: `OnnxGreater(self, ov)`
1909 """
1910 OnnxGreater = loadop('Greater')
1911 opv = self._merge_op_version(self, ov)
1912 return OnnxGreater(self, ov, op_version=opv)
1914 def __lt__(self, ov):
1915 """
1916 Automatically adds operator `OnnxLess` to the graph.
1918 :param ov: onnx node
1919 :return: `OnnxLess(self, ov)`
1920 """
1921 OnnxLess = loadop('Less')
1922 opv = self._merge_op_version(self, ov)
1923 return OnnxLess(self, ov, op_version=opv)
1925 def __eq__(self, ov):
1926 """
1927 Automatically adds operator `OnnxEqual` to the graph.
1929 :param ov: onnx node
1930 :return: `OnnxEqual(self, ov)`
1931 """
1932 OnnxEqual = loadop('Equal')
1933 opv = self._merge_op_version(self, ov)
1934 return OnnxEqual(self, ov, op_version=opv)
1936 def and_(self, ov):
1937 """
1938 Automatically adds operator `OnnxAnd` to the graph.
1940 :param ov: onnx node
1941 :return: `OnnxAnd(self, ov)`
1942 """
1943 OnnxAnd = loadop('And')
1944 opv = self._merge_op_version(self, ov)
1945 return OnnxAnd(self, ov, op_version=opv)
1947 def or_(self, ov):
1948 """
1949 Automatically adds operator `OnnxOr` to the graph.
1951 :param ov: onnx node
1952 :return: `OnnxOr(self, ov)`
1953 """
1954 OnnxOr = loadop('Or')
1955 opv = self._merge_op_version(self, ov)
1956 return OnnxOr(self, ov, op_version=opv)
1958 def __ne__(self, ov):
1959 """
1960 Automatically adds operator `OnnxNot x OnnxEqual` to the graph.
1962 :param ov: onnx node
1963 :return: `OnnxNot(OnnxEqual(self, ov))`
1964 """
1965 OnnxNot, OnnxEqual = loadop('Not', 'Equal')
1966 opv = self._merge_op_version(self, ov)
1967 return OnnxNot(OnnxEqual(self, ov, op_version=opv), op_version=opv)
1969 def __abs__(self):
1970 """
1971 Automatically adds operator `OnnxAbs` to the graph.
1973 :param ov: onnx node
1974 :return: `OnnxAbs(self, ov)`
1975 """
1976 OnnxAbs = loadop('Abs')
1977 return OnnxAbs(self, op_version=self.op_version)
1979 def not_(self):
1980 """
1981 Automatically adds operator `OnnxNot` to the graph.
1983 :param ov: onnx node
1984 :return: `OnnxNot(self, ov)`
1985 """
1986 OnnxNot = loadop('Not')
1987 return OnnxNot(self, op_version=self.op_version)
1989 def astype(self, to):
1990 """
1991 Automatically adds operator `OnnxCast` to the graph.
1993 :param ov: onnx node
1994 :return: `OnnxCast(self, ov, to=to)`
1995 """
1996 OnnxCast = loadop('Cast')
1997 return OnnxCast(self, to=to, op_version=self.op_version)
class OnnxOperatorFunction(OnnxOperator):
    """
    This operator is used to insert existing ONNX function into
    the ONNX graph being built.
    """

    # FIX: ``domain`` was previously assigned twice ('mlprodict' then
    # 'mlprodict.xop'); only the last assignment was effective, so the
    # first, dead one is removed.
    since_version = 1
    expected_inputs = None
    expected_outputs = None
    input_range = [1, 1e9]
    output_range = [1, 1e9]
    op_type = 'Function'
    domain = 'mlprodict.xop'

    @staticmethod
    def attribute_to_value(att):
        """
        Converts an attribute into a value using python structures.

        :param att: instance of :epkg:`AttributeProto`
        :return: python value (float, int, bytes, TensorProto or list)
        :raises NotImplementedError: if the attribute type is not handled
        """
        if isinstance(att, onnx.AttributeProto):
            dtype = att.type
        else:
            raise NotImplementedError(  # pragma: no cover
                "Unable to copy attribute type %r." % type(att))
        if dtype == 1:  # .f
            value = att.f
        elif dtype == 2:  # .i
            value = att.i
        elif dtype == 3:  # .s
            value = att.s
        elif dtype == 4:  # .t
            value = att.t
        elif dtype == 6:  # .floats
            value = list(att.floats)
        elif dtype == 7:  # .ints
            value = list(att.ints)
        elif dtype == 8:  # .strings
            value = list(att.strings)
        elif dtype == 11:  # .double_data
            # NOTE(review): AttributeProto type 11 is SPARSE_TENSOR in
            # onnx and AttributeProto has no `double_data` field — this
            # branch looks suspicious; confirm it is ever reached.
            value = list(att.double_data)
        else:
            raise NotImplementedError(  # pragma: no cover
                "Unable to copy attribute type %r (%r)." % (
                    dtype, att))
        return value

    def __init__(self, function_proto, *inputs, output_names=None):
        """
        :param function_proto: instance of :epkg:`FunctionProto`
        :param inputs: inputs of the function, at most the number of
            inputs declared by *function_proto*
        :param output_names: optional names for the outputs, must match
            the number of outputs declared by *function_proto*
        """
        logger.debug("Function(ONNX, %d in, output_names=%r)",
                     len(inputs), output_names)
        if function_proto is None:
            raise ValueError(
                "function_proto cannot be None.")  # pragma: no cover
        if not isinstance(function_proto, onnx.FunctionProto):
            raise TypeError(  # pragma: no cover
                "function_proto must be of type FunctionProto not %r." %
                type(function_proto))
        if len(inputs) > len(function_proto.input):
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of inputs %r > expected %r." % (
                    len(inputs), len(function_proto.input)))
        if (output_names is not None and
                len(output_names) != len(function_proto.output)):
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of outputs %r != expected %r." % (
                    len(output_names), len(function_proto.output)))
        OnnxOperator.__init__(self, *inputs, output_names=output_names)
        self.model = function_proto

    def __repr__(self):
        "usual"
        atts = {}
        for att in ['output_names']:
            value = getattr(self, att, None)
            if value is not None:
                atts[att] = value
        atts.update(self.kwargs)
        msg = ", ".join("%s=%r" % (k, v) for k, v in atts.items())
        if len(atts) > 0:
            msg = ", " + msg
        return "%s(...%s)" % (
            self.__class__.__name__, msg)

    def add_to(self, builder):
        """
        Adds to graph builder.

        :param builder: instance of @see cl _GraphBuilder,
            it must have a method `add_node`
        """
        logger.debug("Function.add_to(builder)")
        inputs = builder.get_input_names(self, self.inputs)
        n_outputs = len(self.model.output)
        outputs = [builder.get_unique_output_name(NodeResultName(self, i))
                   for i in range(n_outputs)]

        # Register the function once, then the node that calls it.
        builder.add_function(self.model)
        builder.add_node(
            self.model.name, builder.get_unique_name(
                '_fct_' + self.model.name, reserved=False),
            inputs, outputs, domain=self.model.domain)
2104class _GraphBuilder:
2105 """
2106 Graph builder. It takes a graph structure made with
2107 instances of @see cl OnnxOperatorBase.
2108 The main method is `to_onnx`.
2110 * `initializer`: list of initializers to add to the ONNX graph
2111 * `node`: list of nodes to add to the ONNX graph
2112 * `input`: list of inputs to add to the ONNX graph
2113 * `output`: list of inputs to add to the ONNX graph
2114 * `opsets`: opsets of the ONNX graph
2115 * `input_names`: dictionary of input names
2116 `{name: InputDetectedVariable}`
2117 * `node_output_names`: memorizes a name for a node output
2118 when the user did not specify any
2119 `{(id(node), index): OutputDetectedVariable}`
2120 * `reserved_names`: dictionary `{ name : (node, index) }`,
2121 name which should remain unchanged in the ONNX graph
2122 * `names`: list of uniques names
2123 * `functions`: dictionary `{ domain, name: function_proto }`
2124 * `function_hashes`: dictionary `{ domain, name: hash of function_proto }`
2125 """
2127 def __init__(self):
2128 self.initializer = []
2129 self.node = []
2130 self.input = []
2131 self.output = []
2132 self.opsets = {}
2133 self.input_names = {}
2134 self.node_output_names = {}
2135 self.reserved_names = {}
2136 self.names = set()
2137 self.functions = {}
2138 self.function_hashes = {}
2140 def _add_name(self, name):
2141 self.names.add(name)
2143 @staticmethod
2144 def number2alpha(index):
2145 """
2146 Converts a numbers into a string keeping the same
2147 alphabetical order.
2148 """
2149 dec = str(int(index))
2150 if len(dec) == 1:
2151 return dec
2152 return chr(96 + len(dec)) + dec
2154 def reserve_names(self, node, output_names):
2155 """
2156 Adds names to the list of reserved names.
2157 All must be unique.
2159 :param node: node or None for an input
2160 :param output_names: names of the output
2161 """
2162 if output_names is None:
2163 return
2164 for index, var in enumerate(output_names):
2165 if not isinstance(var, Variable):
2166 raise TypeError( # pragma: no cover
2167 "Unexpected type %r for %r." % (type(var), var))
2168 self.reserve_name(node, var.name, index)
2170 def reserve_name(self, node, name, index):
2171 """
2172 Reserves a name so that it cannot be changed.
2174 :param node: node or None for an input
2175 :param name: name
2176 :param index: input index
2177 """
2178 if not isinstance(name, str):
2179 raise TypeError( # pragma: no cover
2180 "Name %r is not a string." % (name, ))
2181 if name in self.reserved_names:
2182 raise RuntimeError( # pragma: no cover
2183 "Name %r is already reserved from node %r, index=%d." % (
2184 name, node, index))
2185 logger.debug("_GraphBuilder.reserve_name([%s-%d], %r, %r)",
2186 node.__class__.__name__, id(node),
2187 name, index)
2188 self.reserved_names[name] = (node, index)
2189 self._add_name(name)
2191 def get_unique_output_name(self, result):
2192 """
2193 Returns a unique output_name for a NodeResultName.
2195 :param result: instance of @see cl NodeResultName
2196 """
2197 if not isinstance(result, NodeResultName):
2198 raise TypeError( # pragma: no cover
2199 "Result must be of type NodeResultName not %r (%r)." % (
2200 type(result), result))
2201 if result.node is None:
2202 key = None, result.index
2203 else:
2204 key = id(result.node), result.index
2205 if key in self.node_output_names:
2206 return self.node_output_names[key]
2207 name = result.get_name()
2208 if name in self.reserved_names:
2209 unique = name
2210 else:
2211 unique = self.get_unique_name(name)
2212 self.node_output_names[key] = unique
2213 return unique
2215 def get_unique_name(self, name, reserved=True):
2216 """
2217 Returns a unique name to name an output.
2219 :param name: name
2220 :param reserved: bypass if the name is a reserved one
2221 :return: unique name, may be the same if not taken already
2222 """
2223 if not isinstance(name, str):
2224 raise TypeError( # pragma: no cover
2225 "name must be a string not %r." % type(name))
2226 if reserved and name in self.reserved_names:
2227 logger.debug( # pragma: no cover
2228 "_GraphBuilder.get_unique_name(%r) 1-> %r", name, name)
2229 return name
2230 if name not in self.names:
2231 self._add_name(name)
2232 logger.debug("_GraphBuilder.get_unique_name(%r) 2-> %r",
2233 name, name)
2234 return name
2235 i = 1
2236 new_name = "%s_%s" % (name, self.number2alpha(i))
2237 while new_name in self.names:
2238 i += 1
2239 new_name = "%s_%s" % (name, self.number2alpha(i))
2240 self._add_name(new_name)
2241 logger.debug("_GraphBuilder.get_unique_name(%r) 3-> %r",
2242 name, new_name)
2243 return new_name
2245 def get_input_names(self, node, inputs):
2246 """
2247 Returns input names for node *node* and inputs *inputs*.
2249 :param node: node
2250 :param inputs: inputs
2251 :return: name
2252 """
2253 names = []
2254 for i in inputs:
2255 if isinstance(i, Variable):
2256 self._add_name(i.name)
2257 names.append(i.name)
2258 self.input_names[i.name] = InputDetectedVariable(None, i)
2259 elif isinstance(i, OnnxOperator):
2260 key = id(i), 0
2261 try:
2262 name = self.node_output_names[key]
2263 except KeyError as e: # pragma: no cover
2264 raise RuntimeError(
2265 "Unable to find key %r for input %r in node %r." % (
2266 key, i, node)) from e
2267 names.append(name)
2268 elif isinstance(i, OnnxOperatorItem):
2269 if isinstance(i.onx_op, OnnxOperatorTuple):
2270 if i.onx_op.values is None:
2271 key = id(i.onx_op.unique), i.index
2272 else:
2273 key = id(i.onx_op[i.index]), 0
2274 elif isinstance(i.onx_op, OnnxOperator):
2275 key = id(i.onx_op), i.index
2276 else:
2277 raise TypeError( # pragma: no cover
2278 "Unexpected type for OnnxOperatorItem: %r." % type(i.onx_op))
2279 try:
2280 name = self.node_output_names[key]
2281 except KeyError as e: # pragma: no cover
2282 raise RuntimeError(
2283 "Unable to find key %r for input %r in node %r." % (
2284 key, i, node)) from e
2285 names.append(name)
2286 elif isinstance(i, OnnxOperatorTuple):
2287 raise NotImplementedError()
2288 elif isinstance(i, numpy.ndarray):
2289 # Adding an initializer
2290 name = self.get_unique_name('init', reserved=False)
2291 init = from_array(i, name)
2292 self.initializer.append(init)
2293 names.append(name)
2294 else:
2295 raise TypeError( # pragma: no cover
2296 "Unexpected type for an input %r." % type(i))
2297 return names
2299 def add_initializer(self, name, init):
2300 """
2301 Adds an initializer to the graph.
2303 :param name: initializer name
2304 :param init: initializer to copy
2305 :return: created intializer
2306 """
2307 if isinstance(init, onnx.TensorProto):
2308 tensor = to_array(init)
2309 val = from_array(tensor, name)
2310 logger.debug("_GraphBuilder.add_initializer:1(%r, %r, %r)",
2311 name, tensor.dtype, tensor.shape)
2312 elif isinstance(init, numpy.ndarray):
2313 value = to_array(init)
2314 val = from_array(value, name)
2315 logger.debug("_GraphBuilder.add_initializer:2(%r, %r, %r)",
2316 name, init.dtype, init.shape)
2317 else:
2318 raise NotImplementedError( # pragma: no cover
2319 "Unsupported initializer type %r." % type(init))
2320 self.initializer.append(val)
2321 return val
2323 def add_function(self, function_proto,
2324 raise_if_exist=False, check_unique=True,
2325 opset=1):
2326 """
2327 Adds a function to the graph.
2329 :param function_proto: instance of type :epkg:`FunctionProto`
2330 :param raise_if_exist: raises an exception if a function of the
2331 same name was already added
2332 :param check_unique: checks if a function was added twice,
2333 it is the same
2334 :param opset: opset for the domain the function belongs to
2335 """
2336 def _hash(p):
2337 m = hashlib.sha256()
2338 m.update(p.SerializeToString())
2339 return m.hexdigest()[:64]
2341 key = function_proto.domain, function_proto.name
2342 if key in self.functions:
2343 if raise_if_exist:
2344 raise RuntimeError( # pragma: no cover
2345 "Function %r is added for the second time." % (key, ))
2346 if check_unique:
2347 hs = _hash(function_proto)
2348 if hs != self.function_hashes[key]:
2349 raise RuntimeError( # pragma: no cover
2350 "Function %r is added for the second time "
2351 "and the content is not the same." % (key, ))
2352 return
2353 self.functions[key] = function_proto
2354 self.function_hashes[key] = _hash(function_proto)
2356 if function_proto.domain not in self.opsets:
2357 self.opsets[function_proto.domain] = opset
2358 else:
2359 self.opsets[function_proto.domain] = max(
2360 opset, self.opsets[function_proto.domain])
2362 def add_node(self, op_type, name, inputs, outputs, domain='',
2363 opset=None, **attributes):
2364 """
2365 Adds a node to the graph.
2367 :param op_type: operator type
2368 :param name: node name
2369 :param inputs: inputs name list
2370 :param outputs: outputs name list
2371 :param domain: node domain
2372 :param opset: node opset
2373 :return: created node
2374 """
2375 if domain is None:
2376 domain = ''
2377 logger.debug("_GraphBuilder.add_node(%r, %r, "
2378 "inputs=%r, outputs=%r, domain=%r, opset=%r)",
2379 op_type, name, inputs, outputs, domain, opset)
2380 if not isinstance(inputs, list):
2381 raise TypeError( # pragma: no cover
2382 "inputs must be a list not %r." % type(inputs))
2383 if not isinstance(outputs, list):
2384 raise TypeError( # pragma: no cover
2385 "inputs must be a list not %r." % type(outputs))
2386 if any(map(lambda x: not isinstance(x, str), inputs)):
2387 raise TypeError( # pragma: no cover
2388 "inputs must be all strings not %r." % inputs)
2389 if any(map(lambda x: not isinstance(x, str), outputs)):
2390 raise TypeError( # pragma: no cover
2391 "outputs must be all strings not %r." % outputs)
2392 if opset is not None:
2393 if domain not in self.opsets:
2394 self.opsets[domain] = opset
2395 else:
2396 self.opsets[domain] = max(opset, self.opsets[domain])
2397 node = make_node(op_type, inputs, outputs, name=name,
2398 domain=domain, **attributes)
2399 self.node.append(node)
2400 return node
2402 def _process_io(self, inputs, input_names):
2403 if inputs is None:
2404 return [
2405 make_tensor_value_info(
2406 'X', TensorProto.FLOAT, None) # pylint: disable=E1101
2407 for name in self.input_names]
2409 if not isinstance(inputs, list):
2410 if is_numpy_dtype(inputs):
2411 inputs = [inputs]
2413 if input_names is None:
2414 # outputs
2415 set_names = set()
2416 input_names = []
2417 new_inputs = []
2418 for inp in inputs:
2419 if isinstance(inp, OutputDetectedVariable):
2420 if inp.name in set_names:
2421 raise ValueError( # pragma: no cover
2422 "Names already taken %r in %r." % (
2423 inp.name, inputs))
2424 set_names.add(inp.name)
2425 key = id(inp.node), inp.index
2426 if key in self.node_output_names:
2427 new_name = self.node_output_names[key]
2428 new_var = OutputDetectedVariable(
2429 inp.node, inp.var.copy_name(new_name), inp.index)
2430 input_names.append(new_var)
2431 new_inputs.append(new_var)
2432 else:
2433 raise RuntimeError( # pragma: no cover
2434 "Key %r is ambiguous or defined in "
2435 "two nodes %r, id(node)=%d, index=%d." % (
2436 key, inp, id(inp.node), inp.index))
2437 else:
2438 raise TypeError( # pragma: no cover
2439 "Unexpected type %r (it should be "
2440 "OutputDetectedVariable) in %r." % (inp, inputs))
2441 inputs = new_inputs
2442 if len(input_names) == 0:
2443 raise RuntimeError( # pragma: no cover
2444 "Unable to cross %r and %r or %r (set_names=%r)." % (
2445 inputs, self.output_names_rev,
2446 self.node_output_names_rev, set_names))
2447 elif not isinstance(input_names, list):
2448 raise RuntimeError( # pragma: no cover
2449 "Unexpected type for input_names %r." % type(input_names))
2450 else:
2451 # inputs
2452 pass
2454 # common parts
2455 if len(input_names) != len(inputs):
2456 raise RuntimeError( # pragma: no cover
2457 "Mismatch between %r and %r." % (
2458 input_names, inputs))
2460 if isinstance(input_names, list):
2461 d_input_names = {}
2462 for inp in input_names:
2463 if inp.name in d_input_names:
2464 raise ValueError( # pragma: no cover
2465 "Duplicated name %r in %r." % (inp.name, input_names))
2466 d_input_names[inp.name] = inp
2467 elif isinstance(input_names, dict):
2468 d_input_names = input_names
2469 else:
2470 raise TypeError( # pragma: no cover
2471 "Unexpected type for input_names %r (%r)." % (
2472 type(input_names), input_names))
2474 # mapping
2475 res = []
2476 for inp in inputs:
2477 if not isinstance(inp, DetectedVariable):
2478 raise TypeError( # pragma: no cover
2479 "inp not DetectedVariable but %r (%r)"
2480 "." % (type(inp), inp))
2481 if inp.name.startswith('???'):
2482 raise RuntimeError( # pragma: no cover
2483 "Issue with variable %r." % inp)
2484 var = d_input_names[inp.name]
2485 if not isinstance(var, DetectedVariable):
2486 raise TypeError( # pragma: no cover
2487 "var not Variable but %r (%r)." % (
2488 type(var), var))
2489 # inp: Variable
2490 # var: str
2491 if inp.var != var.var:
2492 raise RuntimeError( # pragma: no cover
2493 "Unexpected %r != %r." % (inp, var))
2494 res.append(make_tensor_value_info(
2495 inp.name, inp.var.proto_added_type,
2496 inp.var.proto_added_shape))
2498 return res
2500 def to_onnx(self, inputs=None, outputs=None,
2501 target_opset=None, run_shape=False,
2502 optim=True, function_name=None,
2503 function_domain=None, verbose=0):
2504 """
2505 Converts this operator into an ONNX graph.
2507 :param inputs: specific inputs (as a dictionary) or
2508 default inputs if not specified
2509 :param outputs: specific outputs
2510 :param target_opset: dictionary with target opset per domain,
2511 None for the default one
2512 :param run_shape: run shape inference before returning the model
2513 :param optim: optimize the model with function
2514 @see fn onnx_optimisations
2515 :param function_name: if not None builds a :epkg:`FunctionProto`
2516 use this name
2517 :param function_domain: in case of a function, declares the function
2518 as part of this domain, `'mlprodict'` if None
2519 :param verbose: prints information
2520 :return: onnx graph
2521 """
2522 logger.debug("_GraphBuilder.to_onnx(%r, %r, target_opset=%r)",
2523 inputs, outputs, target_opset)
2524 # inputs and outputs
2525 if not all(map(lambda x: isinstance(x, InputDetectedVariable), inputs)):
2526 raise TypeError( # pragma: no cover
2527 "One of the input is not InputDetectedVariable.")
2528 if not all(map(lambda x: isinstance(x, OutputDetectedVariable), outputs)):
2529 raise TypeError( # pragma: no cover
2530 "One of the outputs is not OutputDetectedVariable.")
2531 self.input = self._process_io(inputs, list(self.input_names.values()))
2532 self.output = self._process_io(outputs, None)
2533 logger.debug("_GraphBuilder.to_onnx:self.input=%r",
2534 [i.name for i in self.input])
2535 logger.debug("_GraphBuilder.to_onnx:self.output=%r",
2536 [i.name for i in self.output])
2537 logger.debug("_GraphBuilder.to_onnx:build:n_inputs=%r n_inits=%r n_nodes=%r "
2538 "n_outputs=%r",
2539 len(self.input), len(self.initializer), len(self.node),
2540 len(self.output))
2542 if function_name is not None:
2543 if function_domain is None:
2544 function_domain = 'mlprodict'
2545 if len(self.initializer) > 0:
2546 nodes = []
2547 for init in self.initializer:
2548 nodes.append(
2549 make_node('Constant', [], [init.name], value=init,
2550 name='_init_%s' % init.name))
2551 nodes.extend(self.node)
2552 else:
2553 nodes = self.node
2554 fct = make_function(
2555 function_domain, function_name,
2556 [_.name for _ in self.input],
2557 [_.name for _ in self.output],
2558 nodes,
2559 [make_opsetid(k, v) for k, v in self.opsets.items()])
2560 if optim:
2561 from ..onnx_tools.optim import onnx_optimisations
2562 fct = onnx_optimisations(fct)
2563 return fct
2564 else:
2565 graph = make_graph(
2566 self.node, 'XOP', self.input, self.output, self.initializer)
2567 onnx_model = make_model(
2568 graph, functions=list(self.functions.values()))
2569 opv = self.opsets.get('', max_supported_opset())
2570 opset2ir = _default_OPSET_TO_IR_VERSION()
2571 irv = opset2ir.get(opv, max(opset2ir.values()))
2572 onnx_model.ir_version = irv
2574 logger.debug("_GraphBuilder.to_onnx:2onnx:n_inputs=%r n_inits=%r "
2575 "n_nodes=%r n_outputs=%r",
2576 len(onnx_model.graph.input),
2577 len(onnx_model.graph.initializer),
2578 len(onnx_model.graph.node),
2579 len(onnx_model.graph.output))
2581 del onnx_model.opset_import[:] # pylint: disable=E1101
2582 seen_opset = set()
2583 for k, v in self.opsets.items():
2584 if (k or '') in seen_opset:
2585 raise RuntimeError( # pragma: no cover
2586 "Duplicated opset (%r, %r)." % (k, v))
2587 op_set = onnx_model.opset_import.add() # pylint: disable=E1101
2588 op_set.domain = k or ''
2589 op_set.version = v
2590 seen_opset.add(op_set.domain)
2592 # optimisation, remove redundant constant, unnecessary
2593 # identity nodes.
2594 if optim:
2595 from ..onnx_tools.optim import onnx_optimisations
2596 onnx_model = onnx_optimisations(onnx_model)
2598 logger.debug("_GraphBuilder.to_onnx:optim:n_inputs=%r n_inits=%r "
2599 "n_nodes=%r n_outputs=%r",
2600 len(onnx_model.graph.input),
2601 len(onnx_model.graph.initializer),
2602 len(onnx_model.graph.node),
2603 len(onnx_model.graph.output))
2605 if run_shape:
2606 with_shape = infer_shapes(onnx_model)
2607 logger.debug("_GraphBuilder.to_onnx:shape:n_inputs=%r "
2608 "n_inits=%r n_nodes=%r n_outputs=%r",
2609 len(with_shape.graph.input),
2610 len(with_shape.graph.initializer),
2611 len(with_shape.graph.node),
2612 len(with_shape.graph.output))
2613 return with_shape
2615 logger.debug("_GraphBuilder.to_onnx() -> done")
2616 return onnx_model
# Registry of all known operator schemas, their available versions and
# the set of domains, built once at import time.
_all_schemas, _all_schemas_versions, _all_domains = _populate_schemas()
# Cache of the dynamically created OnnxOperator classes.
_all_classes = {}
# Entry point of the Xop API: attribute access such as `Xop.Add`
# dynamically loads the corresponding operator class.
onnx_load_factory = Xop = OnnxLoadFactory()