Coverage for mlprodict/onnxrt/ops_cpu/__init__.py: 87%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Shortcut to *ops_cpu*.
5"""
6import textwrap
7from ..excs import MissingOperatorError
8from ._op import OpRunCustom
9from ._op_list import __dict__ as d_op_list
# Registry of additional runtime operators, keyed by operator name.
# It is filled by register_operator.
_additional_ops = {}


def register_operator(cls, name=None, overwrite=True):
    """
    Registers a new runtime operator.

    @param      cls         class
    @param      name        by default ``cls.__name__``,
                            or *name* if defined
    @param      overwrite   if True, an operator already registered under
                            the same name is replaced, otherwise
                            an exception is raised

    Raises ``RuntimeError`` if *name* is already registered
    and *overwrite* is False.
    """
    if name is None:
        name = cls.__name__
    if name in _additional_ops and not overwrite:
        raise RuntimeError(  # pragma: no cover
            "Unable to overwrite existing operator '{}': {} "
            "by {}".format(name, _additional_ops[name], cls))
    # Either the name is new or the caller explicitly allowed replacing
    # the existing registration: in both cases the new class wins.
    _additional_ops[name] = cls
def load_op(onnx_node, desc=None, options=None, runtime=None):
    """
    Gets the operator related to the *onnx* node.

    :param onnx_node: :epkg:`onnx` node
    :param desc: internal representation
    :param options: runtime options, key ``'target_opset'`` selects
        the opset the implementation must support
    :param runtime: runtime, forwarded to the operator constructor
    :return: instance of the runtime class implementing the node

    Raises ``ValueError`` if *desc* is None, ``TypeError`` if the target
    opset is not an integer, ``MissingOperatorError`` if no implementation
    is registered for the operator, ``RuntimeError`` if the selected
    implementation requires an opset higher than the requested one.
    """
    from ... import __max_supported_opset__
    if desc is None:
        raise ValueError("desc should not be None.")  # pragma no cover
    # Normalise early: the code below tests membership in *options*
    # and copies it, which would fail on None.
    if options is None:
        options = {}  # pragma: no cover
    name = onnx_node.op_type
    opset = options.get('target_opset', None)
    current_opset = __max_supported_opset__
    chosen_opset = current_opset
    if opset == current_opset:
        # Asking for the latest supported opset is the same as asking
        # for the default implementation.
        opset = None
    if opset is not None:
        if not isinstance(opset, int):
            raise TypeError(  # pragma no cover
                "opset must be an integer not {}".format(type(opset)))
        name_opset = name + "_" + str(opset)
        # Looks for the most recent versioned implementation
        # '<name>_<version>' whose version does not exceed *opset*.
        for op in range(opset, 0, -1):
            nop = name + "_" + str(op)
            if nop in d_op_list:
                name_opset = nop
                chosen_opset = op
                break
    else:
        name_opset = name

    # Registered additional operators take precedence over the built-in
    # list, the versioned name takes precedence over the plain name.
    if name_opset in _additional_ops:
        cl = _additional_ops[name_opset]
    elif name in _additional_ops:
        cl = _additional_ops[name]
    elif name_opset in d_op_list:
        cl = d_op_list[name_opset]
    elif name in d_op_list:
        cl = d_op_list[name]
    else:
        raise MissingOperatorError(  # pragma no cover
            "Operator '{}' from domain '{}' has no runtime yet. "
            "Available list:\n"
            "{} - {}".format(
                name, onnx_node.domain,
                "\n".join(sorted(_additional_ops)),
                "\n".join(textwrap.wrap(
                    " ".join(
                        _ for _ in sorted(d_op_list)
                        if "_" not in _ and _ not in {'cl', 'clo', 'name'})))))

    if hasattr(cl, 'version_higher_than'):
        opv = min(current_opset, chosen_opset)
        if cl.version_higher_than > opv:
            # The chosen implementation does not support
            # the opset version, we need to downgrade it.
            if ('target_opset' in options and
                    options['target_opset'] is not None):  # pragma: no cover
                raise RuntimeError(
                    "Supported version {} > {} (opset={}) required version, "
                    "unable to find an implementation version {} found "
                    "'{}'\n--ONNX--\n{}\n--AVAILABLE--\n{}".format(
                        cl.version_higher_than, opv, opset,
                        options['target_opset'], cl.__name__, onnx_node,
                        "\n".join(
                            _ for _ in sorted(d_op_list)
                            if "_" not in _ and _ not in {'cl', 'clo', 'name'})))
            # Works on a copy so the caller's dictionary is not modified,
            # then retries with an explicit target opset. *runtime* is
            # forwarded so the retry builds the operator the same way.
            options = options.copy()
            options['target_opset'] = current_opset
            return load_op(onnx_node, desc=desc, options=options,
                           runtime=runtime)

    # The keyword was previously misspelt 'runtme', which prevented the
    # operator from receiving the value under its documented name.
    return cl(onnx_node, desc=desc, runtime=runtime, **options)