Coverage for mlprodict/onnxrt/ops_shape/_element_wise.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Computes shape inference for element wise operators.
4"""
5from .shape_excs import ShapeInferenceException
6from .shape_result import ShapeResult, OnnxKind
def _element_wise(known_shapes, node):
    """
    Infers shape for an element wise operator.
    The output shape is obtained by broadcasting the shapes of all
    inputs with ``ShapeResult.broadcast``. It is stored in
    *known_shapes* under the node's first output name.

    Binary operators (Add, Sub, ...) always have two inputs, but some
    operators routed here (Max, Min) may receive more. So every input
    is folded into the result rather than only the first two.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: updated or not (the value returned by
        ``known_shapes.update``)
    :raises ShapeInferenceException: if the node has fewer than two
        inputs or if one of the inputs is not a tensor
    """
    inputs = [known_shapes[name] for name in node.input]
    if len(inputs) < 2:
        # The original implementation failed here with a bare IndexError.
        raise ShapeInferenceException(  # pragma: no cover
            "Element wise operator expects at least two inputs, got %r."
            % list(node.input))
    for res in inputs:
        if res.mtype != OnnxKind.Tensor:
            raise ShapeInferenceException(  # pragma: no cover
                "Result %r must be a tensor." % res)
    output = node.output[0]
    # Pairwise folding of multidirectional broadcasting. For two inputs
    # this is exactly the original single broadcast call.
    result = inputs[0]
    for other in inputs[1:]:
        result = ShapeResult.broadcast(result, other, name=output)
    return known_shapes.update(output, result)
def shape_add(known_shapes, node):
    """Shape inference for operator Add, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_and(known_shapes, node):
    """Shape inference for operator And, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_div(known_shapes, node):
    """Shape inference for operator Div, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_equal(known_shapes, node):
    """Shape inference for operator Equal, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_greater(known_shapes, node):
    """Shape inference for operator Greater, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_greaterorequal(known_shapes, node):
    """Shape inference for operator GreaterOrEqual, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_less(known_shapes, node):
    """Shape inference for operator Less, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_lessorequal(known_shapes, node):
    """Shape inference for operator LessOrEqual, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_max(known_shapes, node):
    """Shape inference for operator Max, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_min(known_shapes, node):
    """Shape inference for operator Min, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_mod(known_shapes, node):
    """Shape inference for operator Mod, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_mul(known_shapes, node):
    """Shape inference for operator Mul, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_or(known_shapes, node):
    """Shape inference for operator Or, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_pow(known_shapes, node):
    """Shape inference for operator Pow, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_sub(known_shapes, node):
    """Shape inference for operator Sub, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated
def shape_xor(known_shapes, node):
    """Shape inference for operator Xor, delegated to _element_wise."""
    updated = _element_wise(known_shapes, node)
    return updated