Coverage for mlprodict/onnxrt/ops_shape/_element_unary.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Computes shape inference for element wise operators with one input.
4"""
5import numpy
6from .shape_excs import ShapeInferenceException
7from .shape_result import OnnxKind
def _element_unary(known_shapes, node, dtype=None):
    """
    Infers shape for an element wise operator with one input.
    The output result is a copy of the first input result,
    optionally with a different element type.
    The function stores the output result into *known_shapes*.

    :param known_shapes: known shapes
    :param node: Onnx node
    :param dtype: None to keep the same type as input,
        not None to change it
    :return: updated or not (value returned by ``known_shapes.update``)
    :raises ShapeInferenceException: if the first input is not a tensor
    """
    x = known_shapes[node.input[0]]
    if x.mtype != OnnxKind.Tensor:
        raise ShapeInferenceException(  # pragma: no cover
            "Result %r must be a tensor." % x)
    # Always work on a copy so the input result is never modified.
    cp = x.copy()
    if dtype is not None:
        cp.dtype = dtype
    return known_shapes.update(node.output[0], cp)
def shape_abs(known_shapes, node):
    """
    Infers shape for operator Abs.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_acos(known_shapes, node):
    """
    Infers shape for operator Acos.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_acosh(known_shapes, node):
    """
    Infers shape for operator Acosh.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_asin(known_shapes, node):
    """
    Infers shape for operator Asin.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_asinh(known_shapes, node):
    """
    Infers shape for operator Asinh.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_atan(known_shapes, node):
    """
    Infers shape for operator Atan.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_atanh(known_shapes, node):
    """
    Infers shape for operator Atanh.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_castlike(known_shapes, node):
    """
    Infers shape for operator CastLike.
    The output takes the shape of the first input and the
    element type of the second input. Both inputs must be tensors;
    they are checked in order, first input first.
    """
    tensors = []
    for index in (0, 1):
        res = known_shapes[node.input[index]]
        if res.mtype != OnnxKind.Tensor:
            raise ShapeInferenceException(  # pragma: no cover
                "Result %r must be a tensor." % res)
        tensors.append(res)
    source, target = tensors
    out = source.copy()
    out.dtype = target.dtype
    return known_shapes.update(node.output[0], out)
def shape_ceil(known_shapes, node):
    """
    Infers shape for operator Ceil.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_celu(known_shapes, node):
    """
    Infers shape for operator Celu.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_clip(known_shapes, node):
    """
    Infers shape for operator Clip.
    Only the first input is examined; the output keeps its
    shape and type.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_cos(known_shapes, node):
    """
    Infers shape for operator Cos.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_cosh(known_shapes, node):
    """
    Infers shape for operator Cosh.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_elu(known_shapes, node):
    """
    Infers shape for operator Elu.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_erf(known_shapes, node):
    """
    Infers shape for operator Erf.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_exp(known_shapes, node):
    """
    Infers shape for operator Exp.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_floor(known_shapes, node):
    """
    Infers shape for operator Floor.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_hardmax(known_shapes, node):
    """
    Infers shape for operator Hardmax.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_hardsigmoid(known_shapes, node):
    """
    Infers shape for operator HardSigmoid.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_identity(known_shapes, node):
    """
    Infers shape for operator Identity.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_isnan(known_shapes, node):
    """
    Infers shape for operator IsNan.
    The output keeps the shape of the first input, but its
    element type becomes ``numpy.bool_``.
    """
    return _element_unary(known_shapes, node, dtype=numpy.bool_)
def shape_isinf(known_shapes, node):
    """
    Infers shape for operator IsInf.
    The output keeps the shape of the first input, but its
    element type becomes ``numpy.bool_``.
    """
    return _element_unary(known_shapes, node, dtype=numpy.bool_)
def shape_leakyrelu(known_shapes, node):
    """
    Infers shape for operator LeakyRelu.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_log(known_shapes, node):
    """
    Infers shape for operator Log.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_logsoftmax(known_shapes, node):
    """
    Infers shape for operator LogSoftmax.
    Shape inference is shared with operator Softmax.
    """
    return shape_softmax(known_shapes=known_shapes, node=node)
def shape_neg(known_shapes, node):
    """
    Infers shape for operator Neg.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_not(known_shapes, node):
    """
    Infers shape for operator Not.
    The first input must have element type ``numpy.bool_``;
    the output keeps its shape and type.

    :raises ShapeInferenceException: if the input type is not boolean
    """
    input_type = known_shapes[node.input[0]].dtype
    if input_type != numpy.bool_:
        raise ShapeInferenceException(
            "Unexpected input type for operator Not %r (must be bool)."
            "" % input_type)
    return _element_unary(known_shapes, node, dtype=None)
def shape_reciprocal(known_shapes, node):
    """
    Infers shape for operator Reciprocal.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_relu(known_shapes, node):
    """
    Infers shape for operator Relu.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_round(known_shapes, node):
    """
    Infers shape for operator Round.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_selu(known_shapes, node):
    """
    Infers shape for operator Selu.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_sigmoid(known_shapes, node):
    """
    Infers shape for operator Sigmoid.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_sign(known_shapes, node):
    """
    Infers shape for operator Sign.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node)
def shape_sin(known_shapes, node):
    """
    Infers shape for operator Sin.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_sinh(known_shapes, node):
    """
    Infers shape for operator Sinh.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_softmax(known_shapes, node):
    """
    Infers shape for operator Softmax.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_sqrt(known_shapes, node):
    """
    Infers shape for operator Sqrt.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_tan(known_shapes, node):
    """
    Infers shape for operator Tan.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_tanh(known_shapes, node):
    """
    Infers shape for operator Tanh.
    The output keeps the shape and type of the first input.
    """
    return _element_unary(known_shapes, node, dtype=None)
def shape_trilu(known_shapes, node):
    """
    Infers shape for operator Trilu.
    Only the first input is examined; the output keeps its
    shape and type.
    """
    return _element_unary(known_shapes, node, dtype=None)