Coverage for mlprodict/onnxrt/shape_object.py: 91%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# pylint: disable=C0302
2"""
3@file
4@brief Shape object.
5"""
6import numpy
class BaseDimensionShape:
    """
    Common interface shared by @see cl DimensionObject,
    @see cl ShapeOperator and @see cl ShapeObject.
    Both methods must be overridden by subclasses.
    """

    def to_string(self, use_x=True):
        """
        Returns a textual representation of the object.

        @param      use_x       unused here, meaning defined by subclasses
        @raises     NotImplementedError always
        """
        raise NotImplementedError()

    def evaluate(self, **kwargs):
        """
        Reduces the expression to a number or a string.

        @param      kwargs      values for the variables, unused here
        @raises     NotImplementedError always
        """
        raise NotImplementedError()  # pragma: no cover
class ShapeOperator(BaseDimensionShape):
    """
    Base class for every operator combining shape dimensions.
    """

    def __init__(self, name, fct, fct_string, *args):
        """
        @param      name            display name of the operator
        @param      fct             callable applied when every argument
                                    evaluates to a non-string value
        @param      fct_string      *fct* written as a string
        @param      args            operands, each a @see cl DimensionObject
        @raises     TypeError       if one operand is not a
                                    @see cl DimensionObject
        """
        self._name = name
        self._fct = fct
        self._fct_string = fct_string
        self._args = args
        wrong_types = [type(arg) for arg in self._args
                       if not isinstance(arg, DimensionObject)]
        if wrong_types:
            # only the first offending type is reported
            raise TypeError(
                "All arguments must be of type DimensionObject not '{}'."
                "".format(wrong_types[0]))

    def __repr__(self):
        """
        usual
        """
        # fct_string appears twice on purpose: once unquoted in the
        # position of *fct*, once quoted as *fct_string*
        template = "{0}('{1}', {2}, '{2}', {3})"
        return template.format(
            type(self).__name__, self._name, self._fct_string, self._args)

    def to_string(self, use_x=True):
        """
        Displays as a string. Must be implemented by subclasses.

        @param      use_x       see subclasses
        @return                 a string
        @raises     NotImplementedError always at this level
        """
        raise NotImplementedError(  # pragma: no cover
            "Operator '{}' does not implement 'to_string': {}.".format(
                self.__class__.__name__, repr(self)))

    def evaluate(self, **kwargs):
        """
        Evaluates the operator.

        Every operand is evaluated first. If at least one of them remains
        a string (an unresolved variable), the result is built by
        *_evaluate_string_*, otherwise *fct* is applied to the values.

        @param      kwargs      value for the variables.
        @return                 string or integer
        @raises     RuntimeError    if *fct* raises a TypeError
        """
        values = [DimensionObject._same_(arg).evaluate(**kwargs)
                  for arg in self._args]
        if any(isinstance(value, str) for value in values):
            return self._evaluate_string_(values, **kwargs)
        try:
            return self._fct(*values)
        except TypeError as e:
            raise RuntimeError(
                "Unable to evaluate operator {} due to {}".format(repr(self), e)) from e

    def _evaluate_string_(self, args, **kwargs):
        """
        Evalutes the operator assuming some of them are still strings.
        Must be overwritten by subclasses.

        @param      args        arguments extracted by method *evaluate*
        @param      kwargs      value for the variables.
        @return                 string or integer
        """
        raise NotImplementedError(
            "This function must be overwritten.")  # pragma: no cover
class ShapeBinaryOperator(ShapeOperator):
    """
    Base class for shape binary operator.
    """

    def __init__(self, name, fct, fct_string, x, y):
        """
        @param      name            display name of the operator
        @param      fct             function doing the operator
                                    if argument are numeric
        @param      fct_string      function represented as a string
        @param      x               first argument
        @param      y               second argument
        """
        super().__init__(name, fct, fct_string, x, y)
        # kept for safety, the parent constructor already requires
        # DimensionObject operands
        for operand, label in ((x, 'x'), (y, 'y')):
            if isinstance(operand, tuple):
                raise TypeError(
                    '{} cannot be a tuple'.format(label))  # pragma: no cover

    def _to_string1(self, x, y):
        # both operands are integers: the result is computed
        value = self._fct(x._dim, y._dim)
        return DimensionObject(value).to_string()

    def _to_string2(self, x, y):
        text = "{}{}{}".format(x._dim, self._name, y._dim)
        return DimensionObject(text).to_string()

    def _to_string2b(self, x, y):
        text = "({}){}({})".format(x._dim, self._name, y._dim)
        return DimensionObject(text).to_string()

    def _to_string3(self, x):
        text = "{}{}x".format(x._dim, self._name)
        return DimensionObject(text).to_string()

    def to_string(self, use_x=True):
        """
        Applies binary operator to a dimension.

        @param      use_x       use `'x'` if dimension is unknown
        @return                 a string
        @raises     TypeError   for operand combinations not handled
        """
        left, right = self._args  # pylint: disable=W0632
        left_dim = left._dim

        if isinstance(left_dim, str):
            right_dim = right._dim
            if isinstance(right_dim, int):
                return self._to_string2(left, right)
            if isinstance(right_dim, str):
                return self._to_string2b(left, right)
            raise TypeError(  # pragma: no cover
                "Unable to handle type '{}'.".format(type(right_dim)))

        if not isinstance(left_dim, int):
            raise TypeError(  # pragma: no cover
                "Unable to handle type '{}'.".format(type(left_dim)))

        if not isinstance(right, DimensionObject):
            raise TypeError(  # pragma: no cover
                "Unable to handle type '{}'.".format(type(right)))

        right_dim = right._dim
        if isinstance(right_dim, int):
            return self._to_string1(left, right)
        if isinstance(right_dim, str):
            return self._to_string2(left, right)
        if right_dim is None:
            if use_x:
                return self._to_string3(left)
            return DimensionObject("{}{}DimensionObject()".format(
                left_dim, self._name)).to_string()
        raise TypeError(  # pragma: no cover
            "Unable to handle type '{}'.".format(type(right_dim)))

    def _evaluate_string_(self, args, **kwargs):
        """
        Evalutes the operator assuming some of them are still strings:
        every argument is wrapped into parenthesis and joined with the
        operator name.

        @param      args        arguments extracted by method *evaluate*
        @param      kwargs      value for the variables.
        @return                 string or integer
        """
        wrapped = ('({})'.format(arg) for arg in args)
        return self._name.join(wrapped)
class ShapeBinaryFctOperator(ShapeBinaryOperator):
    """
    Base class for shape binary operator defined by a function,
    displayed as ``name(x,y)`` instead of ``x name y``.
    """

    def _to_string2(self, x, y):
        text = "{}({},{})".format(self._name, x._dim, y._dim)
        return DimensionObject(text).to_string()

    # same rendering whether both operands are strings or not
    _to_string2b = _to_string2

    def _to_string3(self, x):
        text = "{}({},x)".format(self._name, x._dim)
        return DimensionObject(text).to_string()

    def _evaluate_string_(self, args, **kwargs):
        """
        Evalutes the operator assuming some of them are still strings,
        the result is written as a function call.

        @param      args        arguments extracted by method *evaluate*
        @param      kwargs      value for the variables.
        @return                 string or integer
        """
        joined = ",".join(str(arg) for arg in args)
        return "{}({})".format(self._name, joined)
class ShapeOperatorAdd(ShapeBinaryOperator):
    """
    Shape addition.
    """

    def __init__(self, x, y):
        """
        @param      x       first operand (@see cl DimensionObject)
        @param      y       second operand (@see cl DimensionObject)
        """
        super().__init__(
            '+', lambda left, right: left + right, 'lambda a, b: a + b', x, y)

    def __repr__(self):
        """
        Displays a string.

        @return     a string
        """
        first, second = self._args[0], self._args[1]
        return "{0}({1}, {2})".format(
            type(self).__name__, repr(first), repr(second))
class ShapeOperatorMul(ShapeBinaryOperator):
    """
    Shape multiplication.
    """

    def __init__(self, x, y):
        """
        @param      x       first operand (@see cl DimensionObject)
        @param      y       second operand (@see cl DimensionObject)
        """
        super().__init__(
            '*', lambda left, right: left * right, 'lambda a, b: a * b', x, y)

    def __repr__(self):
        """
        Displays a string.

        @return     a string
        """
        first, second = self._args[0], self._args[1]
        return "{0}({1}, {2})".format(
            type(self).__name__, repr(first), repr(second))
class ShapeOperatorGreater(ShapeBinaryOperator):
    """
    Shape comparison.
    """

    def __init__(self, x, y):
        """
        @param      x       first operand (@see cl DimensionObject)
        @param      y       second operand (@see cl DimensionObject)
        """
        super().__init__(
            '>', lambda left, right: left > right, 'lambda a, b: a > b', x, y)

    def __repr__(self):
        """
        Displays a string.

        @return     a string
        """
        first, second = self._args[0], self._args[1]
        return "{0}({1}, {2})".format(
            type(self).__name__, repr(first), repr(second))
class ShapeOperatorMax(ShapeBinaryFctOperator):
    """
    Best on each dimension: keeps the maximum of both operands.
    """

    def __init__(self, x, y):
        """
        @param      x       first operand (@see cl DimensionObject)
        @param      y       second operand (@see cl DimensionObject)
        """
        super().__init__(
            'max', lambda left, right: max(left, right), 'max(a, b)', x, y)

    def __repr__(self):
        """
        Displays a string.

        @return     a string
        """
        first, second = self._args[0], self._args[1]
        return "{0}({1}, {2})".format(
            type(self).__name__, repr(first), repr(second))
class DimensionObject(BaseDimensionShape):
    """
    One dimension of a shape.
    """

    def __init__(self, obj):
        """
        @param      obj     int (numpy integers are accepted and stored as
                            int), str (a variable name),
                            @see cl ShapeOperator, @see cl DimensionObject,
                            or None, 0, ``'?'`` to specify something unknown
        @raises     TypeError   for any other type
        """
        if obj is None or obj == 0 or obj == '?':
            self._dim = None
        elif isinstance(obj, numpy.integer):
            # stored as a python int so that every method handling
            # ``int`` (to_string, evaluate, __eq__, ...) applies to it
            self._dim = int(obj)
        elif isinstance(obj, (int, str, ShapeOperator, DimensionObject)):
            self._dim = obj
        else:
            raise TypeError("Unexpected type for obj: {}".format(type(obj)))

    @property
    def dim(self):
        """
        Returns the dimension.
        """
        return self._dim

    def __repr__(self):
        """
        usual
        """
        if isinstance(self._dim, int):
            return "DimensionObject({})".format(self._dim)
        if isinstance(self._dim, DimensionObject):
            return repr(self._dim)
        if isinstance(self._dim, ShapeOperator):
            return "DimensionObject({})".format(repr(self._dim))
        return "DimensionObject('{}')".format(self._dim)

    @staticmethod
    def _same_(obj):
        """
        Returns *obj* if *obj* is @see cl DimensionObject
        otherwise converts it.
        """
        if isinstance(obj, DimensionObject):
            return obj
        return DimensionObject(obj)

    def to_string(self, use_x=True):
        """
        Represents the dimension as a string.

        @param      use_x       returns `'x'` for an unknown dimension,
                                `'?'` otherwise
        @return                 a string
        """
        if isinstance(self._dim, int):
            return '{}'.format(self._dim)
        if isinstance(self._dim, ShapeOperator):
            # operators keep their own default for use_x
            return self._dim.to_string()
        if isinstance(self._dim, DimensionObject):
            # nested dimensions are created when a shape is copied
            return self._dim.to_string(use_x=use_x)
        if isinstance(self._dim, str):
            return self._dim
        if self._dim is None:
            return 'x' if use_x else '?'
        raise NotImplementedError(  # pragma: no cover
            "Not implemented for '{}'.".format(repr(self)))

    def evaluate(self, **kwargs):
        """
        Evalutes the dimension.

        A string dimension is replaced by its value in *kwargs* if present.
        An unknown dimension becomes a variable name built from the
        object id, so the same object always evaluates to the same name.

        @param      kwargs      value for the variables.
        @return                 string or integer
        """
        if isinstance(self._dim, (int, ShapeOperator, DimensionObject)):
            res = self._dim
        elif isinstance(self._dim, numpy.integer):
            # only reachable if _dim was assigned directly
            res = int(self._dim)
        elif isinstance(self._dim, str):
            res = kwargs.get(self._dim, self._dim)
        elif self._dim is None:
            pref = str(hex(id(self)))[2:]
            res = "n{}".format(pref)
        else:
            raise NotImplementedError(  # pragma: no cover
                "Not implemented for '{}'.".format(repr(self)))
        if isinstance(res, (ShapeOperator, DimensionObject)):
            return res.evaluate(**kwargs)
        return res

    def __eq__(self, v):
        """
        usual
        """
        if isinstance(v, numpy.integer):
            v = int(v)
        if isinstance(v, (int, str)):
            return self._dim == v
        if isinstance(v, DimensionObject):
            return v == self._dim
        if isinstance(v, ShapeOperator):
            ve = v.evaluate()
            return ve == self._dim
        if v is None:
            return self._dim is None
        raise TypeError(  # pragma: no cover
            "Unable to compare a DimensionObject to {}".format(type(v)))

    def __add__(self, obj):
        """
        usual
        """
        return DimensionObject(
            ShapeOperatorAdd(self, DimensionObject._same_(obj)))

    def __mul__(self, obj):
        """
        usual
        """
        return DimensionObject(
            ShapeOperatorMul(self, DimensionObject._same_(obj)))

    def __gt__(self, obj):
        """
        Operator ``>``. Returns a boolean when both dimensions are
        integers or *obj* is None, otherwise a @see cl DimensionObject
        holding a @see cl ShapeOperatorGreater (such an object is truthy
        since the class defines neither ``__bool__`` nor ``__len__``).
        """
        if obj is None:
            return not isinstance(self._dim, int)
        # plain ints (or strings) are accepted as well
        obj = DimensionObject._same_(obj)
        if isinstance(self._dim, int) and isinstance(obj._dim, int):
            return self._dim > obj._dim
        return DimensionObject(ShapeOperatorGreater(self, obj))
class ShapeObject(BaseDimensionShape):
    """
    Handles mathematical operations around shapes.
    It stores a type (:epkg:`numpy` type),
    and a name to somehow have an idea of where
    the shape comes from in the :epkg:`ONNX` graph.
    The shape itself is defined by a list of
    @see cl DimensionObject or @see cl ShapeOperator
    or *None* if the shape is unknown. A dimension is an
    integer or a variable encoded as a string. This variable
    is a way to tell the dimension may vary.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from mlprodict.onnxrt.shape_object import ShapeObject

        sh1 = ShapeObject((1, 2), dtype=numpy.float32)
        sh2 = ShapeObject((45, 2), dtype=numpy.float32)
        mx = max(sh1, sh2)
        print(mx)

        sh1 = ShapeObject((1, 2), dtype=numpy.float32)
        sh2 = ShapeObject((None, 2), dtype=numpy.float32)
        print(sh2)
        mx = max(sh1, sh2)
        print(mx.to_string())

        sh1 = ShapeObject((1, 2), dtype=numpy.float32)
        sh2 = ShapeObject(('n', 2), dtype=numpy.float32)
        print(sh2)
        mx = max(sh1, sh2)
        print(mx.evaluate(n=4))
    """

    def __init__(self, shape, dtype=None, use_n1=False, name=None,
                 subtype=None):
        """
        @param      shape       tuple, list, `numpy.array`, dictionary with
                                a ``'type'`` key, or None if unknown
        @param      dtype       dtype, ignored when *shape* is an array or a
                                dictionary (the type is taken from it)
        @param      use_n1      use `'n'` if the first dimension is unknown
        @param      name        optional, for debugging purposes
        @param      subtype     element type if this type is a list
        @raises     TypeError   unexpected *shape* or *dtype*
        @raises     ValueError  unexpected kind in a dictionary description
        """
        self.name = name
        self.subtype = subtype
        if isinstance(shape, numpy.ndarray):
            self._shape = [DimensionObject(s) for s in shape.shape]
            self._dtype = shape.dtype
        elif isinstance(shape, dict) and 'type' in shape:
            tshape = shape['type']
            if tshape['kind'] == 'tensor':
                if tshape['shape'] == ('?', ):
                    self._shape = None
                else:
                    self._shape = [DimensionObject(s) for s in tshape['shape']]
                self._dtype = tshape['elem']
            elif tshape['kind'] == 'map':
                self._shape = []
                self._dtype = 'map'
            elif tshape['kind'] == 'sequence':
                self._shape = []
                self._dtype = 'sequence'
            else:
                raise ValueError(  # pragma: no cover
                    "Wrong shape value {}".format(shape))
        elif isinstance(shape, (tuple, list)):
            self._shape = []
            for s in shape:
                self._shape.append(DimensionObject(s))
            self._dtype = dtype
        elif shape is None:
            # shape is unknown
            self._shape = None
            self._dtype = dtype
        else:
            raise TypeError(  # pragma: no cover
                "Unexpected type for shape: {}, shape={}".format(
                    type(shape), shape))

        def _dtype_again():
            # normalizes the many accepted spellings into numpy types
            if self._dtype is None:
                raise TypeError(
                    "dtype cannot be None, shape type is {}\n{}".format(
                        type(shape), shape))
            if isinstance(self._dtype, numpy.dtype):
                # no need to go further
                return
            if self._dtype in (float, 'double', 'tensor(double)'):
                self._dtype = numpy.float64
            elif self._dtype in ('float32', 'float', 'tensor(float)'):
                self._dtype = numpy.float32
            elif self._dtype in (numpy.float16, 'float16', 'tensor(float16)'):
                self._dtype = numpy.float16
            elif self._dtype in ('int32', 'tensor(int32)'):
                self._dtype = numpy.int32
            elif self._dtype in (int, 'int', 'int64', 'tensor(int64)'):
                self._dtype = numpy.int64
            elif self._dtype in (str, 'str', numpy.str_, 'tensor(str)'):
                self._dtype = numpy.str_
            elif (hasattr(self._dtype, 'type') and
                    self._dtype.type is numpy.bytes_):
                # numpy.bytes_ instead of numpy.string_ (same object in
                # numpy 1.x, the alias was removed in numpy 2.0)
                pass
            elif self._dtype in (bool, 'bool', numpy.bool_):
                self._dtype = numpy.bool_
            elif self._dtype in (object, numpy.object_):
                pass
            elif self._dtype in (numpy.int8, 'int8', ):
                self._dtype = numpy.int8
            elif self._dtype in (numpy.uint8, 'uint8', ):
                self._dtype = numpy.uint8
            elif self._dtype in (numpy.int16, 'int16', ):
                self._dtype = numpy.int16
            elif self._dtype in (numpy.uint16, 'uint16', ):
                self._dtype = numpy.uint16
            elif self._dtype in (numpy.uint32, 'uint32', ):
                self._dtype = numpy.uint32
            elif self._dtype in (numpy.uint64, 'uint64', ):
                self._dtype = numpy.uint64
            elif self._dtype in (numpy.complex64, 'complex64', ):
                self._dtype = numpy.complex64
            elif self._dtype in (numpy.complex128, 'complex128', ):
                self._dtype = numpy.complex128
            elif self._dtype == "tensor({'kind': 'tensor', 'elem': 'float', 'shape': })":
                self._dtype = numpy.float32
            elif self._dtype not in {
                    numpy.float32, numpy.float64, numpy.int32, numpy.int64,
                    numpy.str_, numpy.bool_, numpy.float16, None,
                    numpy.complex64, numpy.complex128,
                    'map', 'sequence'}:
                raise TypeError(  # pragma: no cover
                    "dtype has an unexpected value: '{}'.".format(self._dtype))
        try:
            _dtype_again()
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                "Unexpected error with %r of type %r, name=%r." % (
                    (self._dtype, type(self._dtype), name))) from e

        def _shape_again():
            # checks every dimension is a DimensionObject
            if self._shape is not None:
                for i, a in enumerate(self._shape):
                    if not isinstance(a, DimensionObject):
                        raise TypeError(  # pragma: no cover
                            'Dimension {} has a wrong type {}'.format(
                                i, type(a)))
                if use_n1:
                    sh = self._shape[0] if self._shape else None
                    if isinstance(sh, DimensionObject) and sh._dim is None:
                        sh._dim = 'n'
                if self._shape is not None:
                    for s in self._shape:
                        if isinstance(s, int):
                            raise TypeError(  # pragma: no cover
                                "Unexpected type int in shape %r." % self)
        _shape_again()

    def reshape(self, shape):
        """
        Creates a new shape, checks the number of elements is the same
        when both numbers are known integers.

        @param      shape       new shape, any value the constructor accepts
        @return                 @see cl ShapeObject
        @raises     ValueError  if both numbers of elements are integers
                                and differ
        """
        sh = ShapeObject(shape, self.dtype, False, self.name)
        p1 = self.product()
        p2 = sh.product()
        if p1 is None or p2 is None:
            # one shape is unknown, nothing can be checked
            return sh
        p1 = p1.evaluate()
        p2 = p2.evaluate()
        # symbolic products (strings) cannot be compared
        if (isinstance(p1, (int, numpy.integer)) and
                isinstance(p2, (int, numpy.integer)) and p1 != p2):
            raise ValueError("Shape {} cannot be reshaped into {} "
                             "(p1={}, p2={}).".format(self, shape, p1, p2))
        return sh

    def copy(self, dtype=None, name=None):
        """
        A copy not a deepcopy: the list of dimensions is copied, the
        dimensions are not.

        @param      dtype       None or a value to rewrite the type.
        @param      name        overwrites the name
        @return                 @see cl ShapeObject
        """
        new_dtype = self.dtype if dtype is None else dtype
        new_name = name or self.name
        if self._shape is None:
            return ShapeObject(None, dtype=new_dtype, name=new_name,
                               subtype=self.subtype)
        return ShapeObject(self._shape.copy(), new_dtype, name=new_name,
                           subtype=self.subtype)

    def __getitem__(self, index):
        """
        Extracts a specific dimension. Returns None if the shape is
        unknown and the integer 1 if *index* is beyond the last dimension.
        """
        if self._shape is None:
            return None
        if isinstance(index, int) and index >= len(self._shape):
            return 1
        return self._shape[index]

    def __setitem__(self, index, value):
        """
        Changes a specific dimension. The shape is extended with
        dimensions equal to 1 if *index* is beyond the last dimension.
        Does nothing if the shape is unknown.
        """
        if self._shape is None:
            return
        while len(self._shape) <= index:
            self._shape.append(DimensionObject(1))
        self._shape[index] = value

    @property
    def shape(self):
        """
        Returns the stored shape as a tuple, None if unknown.
        """
        if self._shape is None:
            return None
        return tuple(self._shape)

    def __len__(self):
        """
        Returns the number of dimensions, 0 if the shape is unknown.
        """
        if self._shape is None:
            return 0
        return len(self._shape)

    @property
    def dtype(self):
        """
        Returns the stored *dtype*.
        """
        return self._dtype

    def reduce(self, axis=1, keepdims=False, dtype=None):
        """
        Reduces the matrix. Removes one dimension.

        @param      axis        axis (int, None or @see cl ShapeObject
                                holding the axis value)
        @param      keepdims    keep dimensions, replaces the removed
                                dimension by 1
        @param      dtype       if not None, changes the type
        @return                 new dimension
        @raises     IndexError  if an integer *axis* is out of range
        """
        new_dtype = self._dtype if dtype is None else dtype
        if self._shape is None:
            if self.name is None:
                return self.copy(dtype=dtype)
            return self.copy(dtype=dtype, name="{}-RD".format(self.name))
        if axis is None:
            return ShapeObject((1, ), new_dtype,
                               name="{}-RDN".format(self.name))

        if isinstance(axis, ShapeObject):

            def drop_axis(shape, a):
                # *shape* and *a* are the evaluated shapes, the first
                # dimension of *a* holds the axis to remove
                if shape.shape is None:
                    return ShapeObject(None, dtype=new_dtype)
                dims = list(shape)
                first = a[0]
                index = first.dim if isinstance(first, DimensionObject) else first
                if keepdims:
                    dims[index] = DimensionObject(1)
                else:
                    del dims[index]
                return ShapeObject(dims, dtype=new_dtype)

            return ShapeObjectFct(
                drop_axis, self, axis, name="DropAxis", dtype=new_dtype)

        if axis < 0:
            axis = len(self._shape) + axis
        if 0 <= axis < len(self._shape):
            cp = self._shape.copy()
            if keepdims:
                cp[axis] = DimensionObject(1)
            else:
                del cp[axis]
            return ShapeObject(cp, new_dtype,
                               name="{}-RD".format(self.name))
        raise IndexError("axis={} is wrong, shape is {}-tuple and equal to "
                         "{}".format(axis, len(self._shape), self))

    def __repr__(self):
        """
        usual
        """
        st = str(self.dtype)
        if "'" in st:
            st = st.split("'")[1]

        if self.shape is None:
            if self.name is None:
                return "ShapeObject(None, dtype={})".format(st)
            return "ShapeObject(None, dtype={}, name='{}')".format(st, self.name)

        st_shape = []
        for s in self.shape:
            if isinstance(getattr(s, "_dim", None), (int, str)):
                st_shape.append(str(s._dim))
            else:
                st_shape.append(repr(s))
        if len(st_shape) == 1:
            # trailing comma for a one-element tuple
            st_shape.append('')
        st_shape = '({})'.format(", ".join(st_shape))
        if self.name is None:
            return "ShapeObject({}, dtype={})".format(st_shape, st)
        return "ShapeObject({}, dtype={}, name='{}')".format(
            st_shape, st, self.name)

    def __iter__(self):
        """
        Iterators over dimensions, nothing if the shape is unknown.
        """
        if self._shape is not None:
            for d in self._shape:
                yield d

    def __gt__(self, a):
        """
        Compares shapes. Operator ``>``. An unknown shape is greater
        than a known one, a shape with more dimensions is greater,
        otherwise dimensions are compared one by one.
        """
        if isinstance(a, (tuple, list)):
            a = ShapeObject(a, dtype=self._dtype)
        if not isinstance(a, ShapeObject):
            return NotImplemented
        if self._shape is None and a._shape is None:
            return False
        if self._shape is None:
            return True
        if a._shape is None:
            return False
        if len(self) > len(a):
            return True
        if len(self) < len(a):
            return False
        for d1, d2 in zip(self, a):
            if d1 > d2:
                return True
            if d1 < d2:
                return False
        return False

    def __eq__(self, a):
        """
        Tests equality between two shapes.
        """
        if isinstance(a, (tuple, list)):
            a = ShapeObject(a, dtype=self._dtype)
        if not isinstance(a, ShapeObject):
            return NotImplemented
        if self._shape is None and a._shape is None:
            return True
        if self._shape is None or a._shape is None:
            return False
        if len(self) != len(a):
            return False
        return all(d1 == d2 for d1, d2 in zip(self, a))

    def evaluate(self, **kwargs):
        """
        Evaluates the shape. An unknown shape remains unknown.

        @param      kwargs      values for the variables
        @return                 @see cl ShapeObject
        """
        name = "{}-EV".format(self.name)
        if self._shape is None:
            return ShapeObject(None, self._dtype, name=name)
        vs = [v.evaluate(**kwargs) for v in self._shape]
        return ShapeObject(tuple(vs), self._dtype, name=name)

    def to_string(self, use_x=False):
        """
        Converts shapes into a string, ``'?'`` if the shape is unknown.
        """
        if self._shape is None:
            return '?'
        shapes = [a.to_string(use_x=use_x) for a in self._shape]
        return '({})'.format(', '.join(shapes))

    def product(self):
        """
        Multiplies all the dimension.

        @return         @see cl DimensionObject, ``DimensionObject(1)``
                        for an empty shape, None if the shape is unknown
        """
        if self._shape is None:
            return None
        if not self._shape:
            return DimensionObject(1)
        cl = self._shape[0]
        for dim in self._shape[1:]:
            cl = cl * dim
        return cl

    def append(self, dim):
        """
        Appends a dimension. Does nothing if the shape is unknown.
        """
        if self._shape is None:
            return
        if isinstance(dim, DimensionObject):
            self._shape.append(dim)
        else:
            self._shape.append(DimensionObject(dim))

    def insert(self, dim, pos=0):
        """
        Inserts a dimension at position *pos*.
        Does nothing if the shape is unknown.
        """
        if self._shape is None:
            return
        if isinstance(dim, DimensionObject):
            self._shape.insert(pos, dim)
        else:
            self._shape.insert(pos, DimensionObject(dim))

    def squeeze(self, axis):
        """
        Removes one or several dimensions, see method *drop_axis*.
        """
        cp = self.copy(name='{}-SZ'.format(self.name))
        cp.drop_axis(axis)
        return cp

    def unsqueeze(self, axes):
        """
        Adds dimensions equal to 1.

        @param      axes        positions of the new dimensions in the
                                output shape, negative values count from
                                the end of the output shape
        @return                 @see cl ShapeObject
        """
        cp = self.copy(name='{}-USZ'.format(self.name))
        if cp._shape is None:
            return cp
        rank = len(cp._shape) + len(axes)
        # positions refer to the output shape, inserting them in
        # increasing order keeps every position valid
        positions = sorted(ax + rank if ax < 0 else ax for ax in axes)
        for pos in positions:
            cp.insert(1, pos)
        return cp

    def transpose(self, perm):
        """
        Permutes the dimensions.

        @param      perm        permutation, output dimension *i* is input
                                dimension ``perm[i]``
        @return                 @see cl ShapeObject
        """
        if self.shape is None:
            return self.copy(name='{}-TR'.format(self.name))
        cp = ShapeObject([None for p in perm], dtype=self.dtype,
                         name="{}-TR".format(self.name))
        for i, p in enumerate(perm):
            if p < len(self):
                cp._shape[i] = self._shape[p]
            # otherwise (should not happen) the dimension stays unknown
        return cp

    def drop_axis(self, axis):
        """
        Drops an axis or a list of axes (in place).
        Negative values count from the end.
        """
        if self._shape is None:
            return
        if isinstance(axis, (tuple, list)):
            rank = len(self._shape)
            # normalize first so that deleting from the end stays valid
            normalized = [a + rank if a < 0 else a for a in axis]
            for i in sorted(normalized, reverse=True):
                del self._shape[i]
        else:
            del self._shape[axis]

    def broadcast(self, a):
        """
        Computes the shape after a broadcast. Dimensions are aligned
        on the right like :epkg:`numpy` does, every pair of aligned
        dimensions is replaced by their maximum.

        @param      a           @see cl ShapeObject
        @return                 @see cl ShapeObject, a copy of the unknown
                                shape if one of them is unknown
        @raises     ValueError  if *a* is None
        """
        if a is None:
            raise ValueError("a should not be None")  # pragma: no cover
        if a._shape is None:
            return a.copy()
        if self._shape is None:
            return self.copy()
        n_self = len(self._shape)
        n_a = len(a._shape)
        mx = max(n_self, n_a)
        res = []
        for i in range(mx):
            # index of the aligned dimension in each shape, negative when
            # the shorter shape has no dimension at this position
            i_self = i - (mx - n_self)
            i_a = i - (mx - n_a)
            if i_self >= 0 and i_a >= 0:
                res.append(ShapeOperatorMax(self._shape[i_self],
                                            a._shape[i_a]))
            elif i_self >= 0:
                res.append(self._shape[i_self])
            else:
                res.append(a._shape[i_a])
        return ShapeObject(tuple(res), self.dtype, False,
                           name="broadcast-{}-{}".format(self.name, a.name))

    @staticmethod
    def _infer_merged_type(*args, use_dtype=True):
        """
        Infers the type obtained by merging several shapes or types.

        @param      args        shapes, or types if *use_dtype* is False
        @param      use_dtype   reads attribute ``dtype`` of every argument
        @return                 the common type, ``numpy.float64`` if types
                                differ and one of them is a usual numeric type
        @raises     RuntimeError    otherwise
        """
        tys = set(a.dtype for a in args) if use_dtype else set(args)
        if len(tys) == 1:
            return next(iter(tys))
        if tys & {numpy.float64, numpy.int64,
                  numpy.float32, numpy.int32,
                  numpy.float16}:
            return numpy.float64
        raise RuntimeError(  # pragma: no cover
            "Unable to infer types based on {} ({}).".format(
                tys, len(tys)))

    def concat_columns(self, axis, *shapes):
        """
        Concatenates columns from *shapes* to this one
        along one axis.

        @param      axis        concatenation axis
        @param      shapes      other shapes
        @return                 @see cl ShapeObject, unknown if one of the
                                shapes is unknown
        """
        args = [self] + list(shapes)
        dtype = self._infer_merged_type(*args)
        dim_axis = self[axis]
        if dim_axis is None:
            return ShapeObject(None, dtype=dtype)
        if isinstance(dim_axis, int):
            # __getitem__ returns the integer 1 beyond the last dimension
            dim_axis = DimensionObject(dim_axis)
        for a in shapes:
            if a[axis] is None:
                return ShapeObject(None, dtype=dtype)
            dim_axis = dim_axis + a[axis]
        a0 = args[0].copy(dtype=dtype)
        a0[axis] = dim_axis
        return a0

    @staticmethod
    def einsum_shape(equation, *inputs):
        """
        Computes :epkg:`einsum` shapes.
        Not the most efficient one as it creates variables
        of the given shapes.

        @param      equation    equation as bytes or str (ascii),
                                it must contain ``->``
        @param      inputs      input shapes
        @return                 @see cl ShapeObject, the first unknown
                                input if there is one
        @raises     RuntimeError    if the equation and inputs do not match
        """
        for inp in inputs:
            if inp.shape is None:
                return inp
        if isinstance(equation, str):
            equation = equation.encode('ascii')
        if b"->" not in equation:
            raise RuntimeError(  # pragma: no cover
                "Equation %r does not have '->'." % equation)
        inp, out = [_.strip() for _ in equation.split(b"->")]
        inps = [_.strip() for _ in inp.split(b',')]
        if len(inputs) != len(inps):
            raise RuntimeError(  # pragma: no cover
                "Input mismatch between '{}' and {}.".format(equation, inps))
        shs = {}
        for a, b in zip(inps, inputs):
            if len(a) != len(b):
                raise RuntimeError(  # pragma: no cover
                    "Input mismatch '{}' (in '{}') and {}.".format(a, equation, b))
            for c, s in zip(a, b):
                if c not in shs:
                    shs[c] = s
                elif shs[c] != s:
                    raise RuntimeError(  # pragma: no cover
                        "Equation '{}'. Dimension mismatch '{}' != {}.".format(
                            equation, s, shs[c]))
        new_shape = [shs[i] for i in out]
        return ShapeObject(new_shape, dtype=ShapeObject._infer_merged_type(*inputs))

    @staticmethod
    def gather_shape(input, indices, axis):  # pylint: disable=W0622
        """
        Computes Gather shapes: dimension *axis* of *input* is replaced
        by all the dimensions of *indices*.

        @param      input       @see cl ShapeObject
        @param      indices     @see cl ShapeObject
        @param      axis        gather axis, negative counts from the end
        @return                 @see cl ShapeObject, unknown if *input*
                                or *indices* is unknown
        """
        # len() returns 0 for an unknown shape, it cannot detect it
        if input.shape is None or indices.shape is None:
            return ShapeObject(None, dtype=input._dtype)
        input_rank = len(input)
        if axis < 0:
            axis = input_rank + axis

        shape = [input[i] for i in range(axis)]
        shape.extend(indices)
        shape.extend(input[i] for i in range(axis + 1, input_rank))
        return ShapeObject(shape, dtype=input._dtype)
class ShapeObjectFct(ShapeObject):
    """
    Shape computed lazily by a user defined function.
    See @see cl Conv for an example.
    """

    def __init__(self, fct, *shapes, dtype=None, name=None):
        """
        @param      fct         callable receiving the evaluated *shapes*,
                                its result must accept an attribute ``name``
                                (presumably a @see cl ShapeObject)
        @param      shapes      shapes given to *fct* once evaluated
        @param      dtype       dtype
        @param      name        optional, for debugging purposes
        """
        super().__init__(None, dtype=dtype, name=name)
        self._fct = fct
        self._shapes = shapes

    def evaluate(self, **kwargs):
        """
        Evaluates every stored shape, then applies the function to them.

        @param      kwargs      values for the variables
        @return                 result of the function, renamed after this
                                object when it has a name
        """
        evaluated = [shape.evaluate(**kwargs) for shape in self._shapes]
        result = self._fct(*evaluated)
        if self.name is not None:
            result.name = self.name
        return result