Coverage for mlprodict/onnxrt/ops_shape/__init__.py: 95%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Shortcut to *ops_shape*.
4"""
5import textwrap
6from onnx.onnx_cpp2py_export.defs import SchemaError # pylint: disable=E0401,E0611
7from ...onnx_tools.onnx2py_helper import get_onnx_schema
8from ._element_unary import (
9 shape_abs, shape_acos, shape_acosh,
10 shape_asin, shape_asinh, shape_atan, shape_atanh,
11 shape_castlike, shape_ceil, shape_celu,
12 shape_clip, shape_cos, shape_cosh,
13 shape_elu, shape_erf, shape_exp, shape_floor,
14 shape_hardmax, shape_hardsigmoid,
15 shape_identity, shape_isinf, shape_isnan,
16 shape_leakyrelu, shape_log, shape_logsoftmax,
17 shape_neg, shape_not, shape_reciprocal, shape_relu, shape_round,
18 shape_selu,
19 shape_sigmoid, shape_sign, shape_sin, shape_sinh, shape_softmax,
20 shape_sqrt, shape_tan, shape_tanh, shape_trilu)
21from ._element_wise import (
22 shape_add, shape_and,
23 shape_div,
24 shape_equal,
25 shape_greater, shape_greaterorequal,
26 shape_less, shape_lessorequal,
27 shape_max, shape_min, shape_mod, shape_mul,
28 shape_or,
29 shape_pow,
30 shape_sub,
31 shape_xor)
32from ._op_shape_op import shape_det
# Registry of the shape functions imported above, indexed by their name
# (``shape_`` followed by a lower-cased operator type). It is built by
# scanning the module globals at this point of the file: it must stay below
# the imports and above any later definition whose name starts with
# ``shape_``, otherwise that definition would be registered as well.
_shape_functions = {
    k: v for k, v in globals().items() if k.startswith("shape_")
}


# NOTE(review): one-element mutable list, not referenced in the visible part
# of this module -- presumably a debugging counter; confirm before removing.
count = [0]
def shape_dispatch(cache, known_shape, node, rt_class=None):
    """
    Calls the corresponding shape function for a node.

    The function is looked for in *cache* first, then among the
    ``shape_<op_type>`` functions registered in this module, and finally,
    when *rt_class* is specified, in the function body ONNX predefines
    for the operator, which is then run through *rt_class*.

    :param cache: dictionary ``(domain, op_type) -> shape function``,
        updated in place every time a new function is resolved
    :param known_shape: known_shape for all results, indexed by result
        name and updated through its method ``update(name, shape)``
    :param node: onnx node
    :param rt_class: a node may be a predefined function in onnx,
        if no specific function is available, the predefined
        onnx definition is used and run through this runtime
    :return: was *known_shape* updated or not...
    :raises RuntimeError: if no shape function can be found for the node,
        or if the node and the predefined onnx function do not have the
        same number of inputs or outputs
    """
    key = node.domain, node.op_type
    fct_shape = None
    if key in cache:
        fct_shape = cache[key]
    else:
        # Registered functions are named after the lower-cased operator type.
        op_type = "shape_" + node.op_type.lower()
        if op_type in _shape_functions:
            fct_shape = _shape_functions[op_type]
            cache[key] = fct_shape

    if fct_shape is None and rt_class is not None:
        # check this operator is a predefined function in ONNX.
        try:
            onnx_schema = get_onnx_schema(node.op_type, node.domain)
        except SchemaError:
            onnx_schema = None
        if onnx_schema is not None and onnx_schema.has_function:
            sess = rt_class(onnx_schema.function_body)
            if len(node.input) != len(sess.input_names):
                raise RuntimeError(  # pragma: no cover
                    "node and function must have the same number of inputs, "
                    "len(%r) != len(%r)." % (
                        node.input, sess.input_names))
            if len(node.output) != len(sess.output_names):
                raise RuntimeError(  # pragma: no cover
                    "node and function must have the same number of outputs, "
                    "len(%r) != len(%r)." % (
                        node.output, sess.output_names))

            def _shape_function(known_shape, node):
                # Node inputs and outputs are matched by position with the
                # inputs and outputs of the predefined function.
                inputs = {iname: known_shape[name] for name, iname in
                          zip(node.input, sess.input_names)}
                outputs = sess.run(inputs)
                res = False
                for name, oname in zip(node.output, sess.output_names):
                    # Every output is updated, whether or not a previous
                    # one already reported a change.
                    r = known_shape.update(name, outputs[oname])
                    res = res or r
                return res

            fct_shape = _shape_function
            cache[key] = fct_shape

    if fct_shape is not None:
        return fct_shape(known_shape, node)

    raise RuntimeError(  # pragma: no cover
        "Unable to find a corresponding function for operator type %r "
        "domain=%r, looking for %r among\n%s" % (
            node.op_type, node.domain, "shape_" + node.op_type.lower(),
            "\n".join(textwrap.wrap(
                " ".join(sorted(_shape_functions))))))