Coverage for mlprodict/onnxrt/ops_shape/shape_result.py: 92%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Class ShapeResult
4"""
5from enum import Enum
6import numpy
7from .shape_excs import ShapeInferenceException
class OnnxKind(Enum):
    """
    Kind of an ONNX result, used by @see cl ShapeResult (attribute *mtype*).
    """
    # a tensor (see ShapeResult for its shape, dtype and sparse flag)
    Tensor = 0
    # a sequence of results
    Sequence = 1
    # a map
    Map = 2
class ShapeConstraint:
    """
    One constraint: variable *name* must take one of the
    values stored in *values*.

    :param name: variable name, cannot be ``'?'``
    :param values: set of possible values
    :raises ValueError: if *name* is ``'?'``
    :raises TypeError: if *values* is not a set
    """

    def __init__(self, name, values):
        if name == '?':
            raise ValueError(  # pragma: no cover
                "Name cannot be '?'.")
        if not isinstance(values, set):
            raise TypeError(  # pragma: no cover
                "values must be a set not %r." % type(values))
        self.name = name
        self.values = values

    def __eq__(self, other):
        "usual"
        if not isinstance(other, ShapeConstraint):
            # Comparing with an unrelated object (None, a string...)
            # must not fail with AttributeError: let Python fall back
            # on the reflected comparison / identity (hence False).
            return NotImplemented
        return self.name == other.name and self.values == other.values

    def __repr__(self):
        "usual"
        return "%s(%r, %r)" % (
            self.__class__.__name__, self.name, self.values)

    def merge(self, cst):
        """
        Merges this constraint with *cst* into this one:
        the possible values become the intersection of both sets.

        :param cst: a @see cl ShapeConstraint or a list of them,
            merged one after another
        """
        if isinstance(cst, list):
            for c in cst:
                self.merge(c)
            return
        self.values = self.values.intersection(cst.values)

    def copy(self, deep=False):
        """
        Makes a copy of the object. The set of values is always
        copied, *deep* is accepted for API symmetry and ignored.
        """
        return ShapeConstraint(self.name, self.values.copy())
67class ShapeConstraintList:
68 """
69 A list of ShapeConstraint.
70 """
72 def __init__(self):
73 self.csts = []
75 def __contains__(self, cst):
76 for a in self.csts:
77 if cst == a:
78 return True
79 return False
81 def append(self, cst):
82 "Appends a new constraint to the list."
83 self.csts.append(cst)
85 def __repr__(self):
86 return "ShapeConstraintList(%r)" % self.csts
88 def __iter__(self):
89 for c in self.csts:
90 yield c
92 def __len__(self):
93 return len(self.csts)
95 def copy(self, deep=False):
96 """
97 Copies the object.
98 """
99 cp = ShapeConstraintList()
100 if deep:
101 cp.csts = [v.copy(deep=deep) for v in self]
102 else:
103 cp.csts = self.csts.copy()
104 return cp
class ShapeResult:
    """
    Contains information about shape and type of a result
    in an onnx graph.

    :param name: result name
    :param shape: shape if the result is a tensor, a sequence whose
        items are integers or variable names (strings); None is
        accepted and stored as an empty list, which @see me merge
        treats as an unknown shape
    :param dtype: element type if the result is a tensor
    :param sparse: is the tensor sparse
    :param mtype: kind of the result (see class @see cl OnnxKind)
    :param constraints: list of constraints applying on variables,
        a @see cl ShapeConstraintList or None for an empty list
    """

    def __init__(self, name, shape=None, dtype=None, sparse=False,
                 mtype=OnnxKind.Tensor, constraints=None):
        if not isinstance(name, str):
            raise TypeError(  # pragma: no cover
                "name must be a string not %r." % type(name))
        if not isinstance(sparse, bool):
            raise TypeError(  # pragma: no cover
                "sparse must be a boolean not %r." % sparse)
        if not isinstance(mtype, OnnxKind):
            raise TypeError(  # pragma: no cover
                "mtype must be of type OnnxKind not %r." % type(mtype))
        # The default shape=None used to fail in list(None); an empty
        # list is what merge already interprets as an unknown shape.
        self.shape = [] if shape is None else list(shape)
        # Validate the stored copy: *shape* may be a one-shot iterable.
        for dim in self.shape:
            if dim in ('', None, '?'):
                raise ValueError(  # pragma: no cover
                    "All dimensions must be an int or a variable name, "
                    "%s is not." % (shape, ))
        self.name = name
        self.mtype = mtype
        self.dtype = dtype
        self.sparse = sparse
        if constraints is None:
            self.constraints = ShapeConstraintList()
        elif isinstance(constraints, ShapeConstraintList):
            self.constraints = constraints
        else:
            raise TypeError(  # pragma: no cover
                "constraints must be of type(ShapeConstraintList).")

    def is_compatible(self, shape):
        """
        Tells if this shape is compatible with the given tuple.

        :param shape: tuple, or numpy array (its shape is used)
        :return: boolean
        :raises NotImplementedError: if this shape contains
            non-integer dimensions
        """
        if isinstance(shape, numpy.ndarray):
            shape = shape.shape
        if all(isinstance(x, int) for x in self.shape):
            return tuple(self.shape) == tuple(shape)
        raise NotImplementedError("%r ? %r" % (self, shape))

    def copy(self, deep=False):
        """
        Returns a copy for the result. The shape list is copied,
        constraints are copied with the same *deep* flag.
        """
        return ShapeResult(self.name, self.shape, self.dtype, self.sparse,
                           self.mtype, self.constraints.copy(deep=deep))

    def __repr__(self):
        """
        Usual
        """
        if len(self.constraints) > 0:
            return "%s(%r, %r, %r, sparse=%r, mtype=%r, constraints=%r)" % (
                self.__class__.__name__, self.name, self.shape, self.dtype,
                self.sparse, self.mtype, self.constraints)
        if self.mtype != OnnxKind.Tensor:
            return "%s(%r, %r, %r, sparse=%r, mtype=%r)" % (
                self.__class__.__name__, self.name, self.shape, self.dtype,
                self.sparse, self.mtype)
        if self.sparse:
            return "%s(%r, %r, %r,sparse=%r)" % (
                self.__class__.__name__, self.name, self.shape, self.dtype,
                self.sparse)
        return "%s(%r, %r, %r)" % (
            self.__class__.__name__, self.name, self.shape, self.dtype)

    def __eq__(self, shape):
        """
        Tells if two shapes are identical: same kind, shape, dtype
        and sparsity. Names and constraints are not compared.
        """
        if not isinstance(shape, ShapeResult):
            # Avoid AttributeError when compared with unrelated objects.
            return NotImplemented
        return (self.mtype == shape.mtype and self.shape == shape.shape and
                self.dtype == shape.dtype and self.sparse == shape.sparse)

    def n_dims(self):
        """
        Returns the number of dimensions if it is a tensor.
        Raises an exception otherwise.
        """
        if self.mtype != OnnxKind.Tensor:
            raise ShapeInferenceException(  # pragma: no cover
                "This shape is not a tensor %r." % self)
        return len(self.shape)

    def merge(self, other_result):
        """
        Merges constraints from *other_result* into *self*.
        If *self* has an empty (unknown) shape, it takes the shape of
        *other_result*. Every dimension where one side is a variable and
        the other differs adds a constraint on that variable.

        :param other_result: ShapeResult to merge
        :return: True if *self* was modified
        :raises RuntimeError: if kinds, ranks or integer
            dimensions are inconsistent
        """
        if self.mtype != other_result.mtype:
            raise RuntimeError(  # pragma: no cover
                "Unable to merge %r and %r." % (self, other_result))
        if (len(self.shape) != 0 and len(other_result.shape) != 0 and
                len(self.shape) != len(other_result.shape)):
            raise RuntimeError(  # pragma: no cover
                "Length mismatch, unable to merge %r and %r." % (
                    self, other_result))
        updated = False
        if other_result.constraints is not None:
            for c in other_result.constraints:
                if c not in self.constraints:
                    self.constraints.append(c)
                    updated = True

        if len(self.shape) == 0 and len(other_result.shape) > 0:
            # Then self.shape is unknown and the other one is.
            self.shape = other_result.shape.copy()
            return True

        for a, b in zip(self.shape, other_result.shape):
            if a == b:
                continue
            if isinstance(a, int) and isinstance(b, int):
                raise RuntimeError(
                    "Inconsistancy between %r and %r." % (
                        self, other_result))
            elif isinstance(a, str):
                c = ShapeConstraint(a, {b})
                if c not in self.constraints:
                    updated = True
                    self.constraints.append(c)
            elif isinstance(b, str):
                c = ShapeConstraint(b, {a})
                if c not in self.constraints:
                    updated = True
                    self.constraints.append(c)
            else:
                raise NotImplementedError(  # pragma: no cover
                    "Merge not implemented between %r and %r." % (
                        self, other_result))
        return updated

    def resolve(self, variables):
        """
        Results variables in a shape using values stored
        in *variables*. It does not copy any constraints.
        A variable mapped to None is left unresolved, a variable
        mapped to a single value is replaced by that value, otherwise
        it is replaced by the set of possible values.

        :param variables: dictionary `{ name: values }`
        :return: new ShapeResult
        :raises RuntimeError: if a variable is missing from *variables*
        """
        res = ShapeResult(self.name, shape=self.shape, dtype=self.dtype,
                          sparse=self.sparse, mtype=self.mtype)
        for i, dim in enumerate(res.shape):
            if not isinstance(dim, str):
                continue
            if dim not in variables:
                raise RuntimeError(  # pragma: no cover
                    "Unable to resolve shape %r due to missing "
                    "%r." % (self, dim))
            vals = variables[dim]
            if vals is None:
                # size unknown
                continue
            if len(vals) == 1:
                res.shape[i] = next(iter(vals))
            else:
                res.shape[i] = set(vals)
        return res

    @staticmethod
    def broadcast(sh1, sh2, name=None):
        """
        Broadcasts dimensions for an element wise operator.
        Shapes of different ranks are aligned on their last dimension
        and the shorter one is padded with leading 1s (numpy rule).
        An empty shape is only accepted when the other shape is
        ``[1]`` or has no dimension either.

        :param sh1: ShapeResult
        :param sh2: ShapeResult
        :param name: name of the output ShapeResult, must not be empty
        :return: ShapeResult
        :raises ShapeInferenceException: if shapes or dtypes
            cannot be broadcast
        """
        if not isinstance(sh1, ShapeResult):
            raise TypeError(  # pragma: no cover
                "Unexpected type for sh1 %r." % type(sh1))
        if not isinstance(sh2, ShapeResult):
            raise TypeError(  # pragma: no cover
                "Unexpected type for sh2 %r." % type(sh2))
        if sh1.mtype != OnnxKind.Tensor:
            raise TypeError(  # pragma: no cover
                "sh1 must be a tensor not %r." % sh1.mtype)
        if sh2.mtype != OnnxKind.Tensor:
            raise TypeError(  # pragma: no cover
                "sh2 must be a tensor not %r." % sh2.mtype)
        n1, n2 = sh1.n_dims(), sh2.n_dims()
        if n1 != n2:
            if n1 == 1 and sh1.shape[0] == 1:
                return ShapeResult(
                    name, sh2.shape, sh2.dtype, sh2.sparse, sh2.mtype)
            if n2 == 1 and sh2.shape[0] == 1:
                return ShapeResult(
                    name, sh1.shape, sh1.dtype, sh1.sparse, sh1.mtype)
            if n1 == 0 or n2 == 0:
                # An empty shape may also mean 'unknown' (see merge),
                # padding it with 1s would only be a guess.
                raise ShapeInferenceException(  # pragma: no cover
                    "Broadcasting is not implemented when one shape is "
                    "empty, shapes are %r and %r." % (sh1, sh2))
        if sh1.dtype != sh2.dtype:
            raise ShapeInferenceException(  # pragma: no cover
                "Cannot broadcast shapes %r and %r (dtypes)."
                "" % (sh1, sh2))

        # Right-align both shapes (numpy / ONNX multidirectional rule).
        rank = max(n1, n2)
        dims1 = [1] * (rank - n1) + list(sh1.shape)
        dims2 = [1] * (rank - n2) + list(sh2.shape)

        constraints = ShapeConstraintList()
        shape = []
        for a, b in zip(dims1, dims2):
            if isinstance(a, int) and isinstance(b, int):
                if a == b or b == 1:
                    d = a
                elif a == 1:
                    # 1 broadcasts to any size, including 0
                    d = b
                else:
                    raise ShapeInferenceException(  # pragma: no cover
                        "Cannot broadcast shapes %r and %r (dimensions)."
                        "" % (sh1, sh2))
            elif isinstance(a, int):
                if a != 1:
                    d = a
                    constraints.append(ShapeConstraint(b, {1, a}))
                else:
                    d = b
            elif isinstance(b, int):
                if b != 1:
                    d = b
                    constraints.append(ShapeConstraint(a, {1, b}))
                else:
                    d = a
            elif a == b:
                d = a
            else:
                raise ShapeInferenceException(  # pragma: no cover
                    "Cannot broadcast shapes %r and %r." % (sh1, sh2))
            shape.append(d)
        if name in (None, ''):
            raise ValueError(  # pragma: no cover
                "name cannot be empty.")
        res = ShapeResult(name, shape, sh1.dtype, sh1.sparse or sh2.sparse,
                          sh1.mtype, constraints)
        return res