Coverage for mlprodict/onnxrt/shape_object.py: 91%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

504 statements  

1# pylint: disable=C0302 

2""" 

3@file 

4@brief Shape object. 

5""" 

6import numpy 

7 

8 

class BaseDimensionShape:
    """
    Common ancestor of @see cl DimensionObject,
    @see cl ShapeOperator and @see cl ShapeObject.
    It only declares the interface, both methods raise
    until a subclass overrides them.
    """

    def to_string(self, use_x=True):
        """
        Returns a string representation of the object.

        @param      use_x   unused by this base implementation
        @return             nothing, this version always raises
                            *NotImplementedError*
        """
        raise NotImplementedError()

    def evaluate(self, **kwargs):
        """
        Reduces the expression to a number or a string.

        @param      kwargs  values for the variables
        @return             nothing, this version always raises
                            *NotImplementedError*
        """
        raise NotImplementedError()  # pragma: no cover

27 

28 

class ShapeOperator(BaseDimensionShape):
    """
    Root class of every operator applied to dimensions.
    """

    def __init__(self, name, fct, fct_string, *args):
        """
        @param      name        display name of the operator
        @param      fct         callable applied to the operands
                                once all of them are numeric
        @param      fct_string  *fct* written as a string
        @param      args        operands, each of them must be a
                                @see cl DimensionObject
        """
        self._name = name
        self._fct = fct
        self._fct_string = fct_string
        self._args = args
        for operand in args:
            if isinstance(operand, DimensionObject):
                continue
            raise TypeError(
                "All arguments must be of type DimensionObject not '{}'."
                "".format(type(operand)))

    def __repr__(self):
        """
        usual
        """
        cls_name = self.__class__.__name__
        return "{0}('{1}', {2}, '{2}', {3})".format(
            cls_name, self._name, self._fct_string, self._args)

    def to_string(self, use_x=True):
        """
        Displays the operator as a string, must be
        overridden by subclasses.

        @param      use_x   unused here
        @return             a string
        """
        raise NotImplementedError(  # pragma: no cover
            "Operator '{}' does not implement 'to_string': {}.".format(
                self.__class__.__name__, repr(self)))

    def evaluate(self, **kwargs):
        """
        Evaluates the operator.

        @param      kwargs  value for the variables.
        @return             string or integer
        """
        # Every operand is evaluated first, in order.
        values = [DimensionObject._same_(operand).evaluate(**kwargs)
                  for operand in self._args]
        # A remaining string means a variable is still unknown:
        # the result stays symbolic.
        if any(isinstance(v, str) for v in values):
            return self._evaluate_string_(values, **kwargs)
        try:
            return self._fct(*values)
        except TypeError as e:
            raise RuntimeError(
                "Unable to evaluate operator {} due to {}".format(repr(self), e)) from e

    def _evaluate_string_(self, args, **kwargs):
        """
        Evaluates the operator when at least one operand is
        still a string, must be overridden by subclasses.

        @param      args    arguments extracted by method *evaluate*
        @param      kwargs  value for the variables.
        @return             string or integer
        """
        raise NotImplementedError(
            "This function must be overwritten.")  # pragma: no cover

105 

106 

class ShapeBinaryOperator(ShapeOperator):
    """
    Root class of operators taking exactly two operands.
    """

    def __init__(self, name, fct, fct_string, x, y):
        """
        @param      name        display name of the operator
        @param      fct         callable applied to the operands
                                once both are numeric
        @param      fct_string  *fct* written as a string
        @param      x           first argument
        @param      y           second argument
        """
        ShapeOperator.__init__(self, name, fct, fct_string, x, y)
        for label, value in (('x', x), ('y', y)):
            if isinstance(value, tuple):
                raise TypeError(  # pragma: no cover
                    '{} cannot be a tuple'.format(label))

    def _to_string1(self, x, y):
        # both dimensions are integers: the operator is applied
        value = self._fct(x._dim, y._dim)
        return DimensionObject(value).to_string()

    def _to_string2(self, x, y):
        # infix notation without parentheses
        text = "{}{}{}".format(x._dim, self._name, y._dim)
        return DimensionObject(text).to_string()

    def _to_string2b(self, x, y):
        # infix notation, both operands between parentheses
        text = "({}){}({})".format(x._dim, self._name, y._dim)
        return DimensionObject(text).to_string()

    def _to_string3(self, x):
        # second operand unknown, displayed as 'x'
        text = "{}{}x".format(x._dim, self._name)
        return DimensionObject(text).to_string()

    def to_string(self, use_x=True):
        """
        Applies binary operator to a dimension.

        @param      use_x   use `'x'` if dimension is unknown
        @return             a string
        """
        left, right = self._args  # pylint: disable=W0632
        ldim = left._dim

        if isinstance(ldim, int):
            if not isinstance(right, DimensionObject):
                raise TypeError(  # pragma: no cover
                    "Unable to handle type '{}'.".format(type(right)))
            rdim = right._dim
            if isinstance(rdim, int):
                return self._to_string1(left, right)
            if isinstance(rdim, str):
                return self._to_string2(left, right)
            if rdim is not None:
                raise TypeError(  # pragma: no cover
                    "Unable to handle type '{}'.".format(type(rdim)))
            if use_x:
                return self._to_string3(left)
            return DimensionObject("{}{}DimensionObject()".format(
                ldim, self._name)).to_string()

        if isinstance(ldim, str):
            rdim = right._dim
            if isinstance(rdim, int):
                return self._to_string2(left, right)
            if isinstance(rdim, str):
                return self._to_string2b(left, right)
            raise TypeError(  # pragma: no cover
                "Unable to handle type '{}'.".format(type(rdim)))

        raise TypeError(  # pragma: no cover
            "Unable to handle type '{}'.".format(type(ldim)))

    def _evaluate_string_(self, args, **kwargs):
        """
        Evaluates the operator when some operands are still
        strings: every operand is put between parentheses and
        they are joined with the operator name.

        @param      args    arguments extracted by method *evaluate*
        @param      kwargs  value for the variables.
        @return             string or integer
        """
        return self._name.join('({})'.format(a) for a in args)

181 

182 

class ShapeBinaryFctOperator(ShapeBinaryOperator):
    """
    Root class of binary operators displayed with a
    function notation ``name(x,y)`` instead of an infix one.
    """

    def _to_string2(self, x, y):
        text = "{}({},{})".format(self._name, x._dim, y._dim)
        return DimensionObject(text).to_string()

    def _to_string2b(self, x, y):
        text = "{}({},{})".format(self._name, x._dim, y._dim)
        return DimensionObject(text).to_string()

    def _to_string3(self, x):
        # second operand unknown, displayed as 'x'
        text = "{}({},x)".format(self._name, x._dim)
        return DimensionObject(text).to_string()

    def _evaluate_string_(self, args, **kwargs):
        """
        Evaluates the operator when some operands are still
        strings, the result is ``name(arg1,arg2)``.

        @param      args    arguments extracted by method *evaluate*
        @param      kwargs  value for the variables.
        @return             string or integer
        """
        joined = ",".join(str(a) for a in args)
        return "{}({})".format(self._name, joined)

206 

207 

class ShapeOperatorAdd(ShapeBinaryOperator):
    """
    Addition of two dimensions.
    """

    def __init__(self, x, y):
        """
        @param      x   first operand
        @param      y   second operand
        """
        ShapeBinaryOperator.__init__(
            self, '+', lambda a, b: a + b, 'lambda a, b: a + b', x, y)

    def __repr__(self):
        """
        Displays the class name followed by both operands.

        @return a string
        """
        first, second = self._args[0], self._args[1]
        return "{0}({1}, {2})".format(
            self.__class__.__name__, repr(first), repr(second))

225 

226 

class ShapeOperatorMul(ShapeBinaryOperator):
    """
    Multiplication of two dimensions.
    """

    def __init__(self, x, y):
        """
        @param      x   first operand
        @param      y   second operand
        """
        ShapeBinaryOperator.__init__(
            self, '*', lambda a, b: a * b, 'lambda a, b: a * b', x, y)

    def __repr__(self):
        """
        Displays the class name followed by both operands.

        @return a string
        """
        first, second = self._args[0], self._args[1]
        return "{0}({1}, {2})".format(
            self.__class__.__name__, repr(first), repr(second))

244 

245 

class ShapeOperatorGreater(ShapeBinaryOperator):
    """
    Comparison ``>`` of two dimensions.
    """

    def __init__(self, x, y):
        """
        @param      x   first operand
        @param      y   second operand
        """
        ShapeBinaryOperator.__init__(
            self, '>', lambda a, b: a > b, 'lambda a, b: a > b', x, y)

    def __repr__(self):
        """
        Displays the class name followed by both operands.

        @return a string
        """
        first, second = self._args[0], self._args[1]
        return "{0}({1}, {2})".format(
            self.__class__.__name__, repr(first), repr(second))

263 

264 

class ShapeOperatorMax(ShapeBinaryFctOperator):
    """
    Maximum of two dimensions.
    """

    def __init__(self, x, y):
        """
        @param      x   first operand
        @param      y   second operand
        """
        ShapeBinaryFctOperator.__init__(
            self, 'max', lambda a, b: max(a, b), 'max(a, b)', x, y)

    def __repr__(self):
        """
        Displays the class name followed by both operands.

        @return a string
        """
        first, second = self._args[0], self._args[1]
        return "{0}({1}, {2})".format(
            self.__class__.__name__, repr(first), repr(second))

282 

283 

class DimensionObject(BaseDimensionShape):
    """
    One dimension of a shape. The dimension is stored in
    attribute *_dim* which is an integer, a string (a variable
    name), a @see cl ShapeOperator, another
    @see cl DimensionObject or *None* if it is unknown.
    """

    def __init__(self, obj):
        """
        @param  obj     int or @see cl DimensionObject or None to
                        specify something unknown

        ``None``, ``0`` and ``'?'`` all mean *unknown*.
        :epkg:`numpy` integers are converted into regular integers
        so that every other method only deals with type *int*.
        """
        if obj is None or obj == 0 or obj == '?':
            self._dim = None
        elif isinstance(obj, (numpy.int32, numpy.int64)):
            # numpy integers do not inherit from int, keeping them
            # as is made evaluate, to_string and __eq__ fail
            self._dim = int(obj)
        elif isinstance(obj, (int, str, ShapeOperator, DimensionObject)):
            self._dim = obj
        else:
            raise TypeError("Unexpected type for obj: {}".format(type(obj)))

    @property
    def dim(self):
        """
        Returns the dimension.
        """
        return self._dim

    def __repr__(self):
        """
        usual
        """
        if isinstance(self._dim, int):
            return "DimensionObject({})".format(self._dim)
        if isinstance(self._dim, DimensionObject):
            return repr(self._dim)
        if isinstance(self._dim, ShapeOperator):
            return "DimensionObject({})".format(repr(self._dim))
        return "DimensionObject('{}')".format(self._dim)

    @staticmethod
    def _same_(obj):
        """
        Returns *obj* if *obj* is @see cl DimensionObject
        otherwise converts it.
        """
        if isinstance(obj, DimensionObject):
            return obj
        return DimensionObject(obj)

    def to_string(self, use_x=True):
        """
        Represents the dimension as a string.

        @param      use_x   an unknown dimension is displayed
                            as ``'x'`` if True, ``'?'`` otherwise
        @return             a string
        """
        if isinstance(self._dim, int):
            return '{}'.format(self._dim)
        if isinstance(self._dim, ShapeOperator):
            # the operator is displayed with its own default for use_x
            return self._dim.to_string()
        if isinstance(self._dim, str):
            return self._dim
        if self._dim is None:
            return 'x' if use_x else '?'
        raise NotImplementedError(  # pragma: no cover
            "Not implemented for '{}'.".format(repr(self)))

    def evaluate(self, **kwargs):
        """
        Evaluates the dimension.

        @param  kwargs      value for the variables.
        @return             string or integer

        A variable missing from *kwargs* is returned as a string,
        an unknown dimension is given a name built from the
        object id.
        """
        if isinstance(self._dim, (int, ShapeOperator, DimensionObject)):
            res = self._dim
        elif isinstance(self._dim, str):
            res = kwargs.get(self._dim, self._dim)
        elif self._dim is None:
            pref = str(hex(id(self)))[2:]
            res = "n{}".format(pref)
        elif isinstance(self._dim, numpy.integer):
            # only possible if _dim was modified after the
            # constructor was called
            res = int(self._dim)
        else:
            raise NotImplementedError(  # pragma: no cover
                "Not implemented for '{}'.".format(repr(self)))
        # an operator, a nested dimension or a value coming from
        # kwargs may need to be evaluated as well
        if isinstance(res, (ShapeOperator, DimensionObject)):
            return res.evaluate(**kwargs)
        return res

    def __eq__(self, v):
        """
        usual
        """
        if isinstance(v, (int, str, numpy.integer)):
            return self._dim == v
        if isinstance(v, DimensionObject):
            return v == self._dim
        if isinstance(v, ShapeOperator):
            ve = v.evaluate()
            return ve == self._dim
        if v is None:
            return self._dim is None
        raise TypeError(  # pragma: no cover
            "Unable to compare a DimensionObject to {}".format(type(v)))

    def __add__(self, obj):
        """
        usual
        """
        return DimensionObject(
            ShapeOperatorAdd(self, DimensionObject._same_(obj)))

    def __mul__(self, obj):
        """
        usual
        """
        return DimensionObject(
            ShapeOperatorMul(self, DimensionObject._same_(obj)))

    def __gt__(self, obj):
        """
        usual

        @param      obj     None, a @see cl DimensionObject or any
                            value accepted by the constructor
        @return             a boolean if both dimensions are
                            integers (or *obj* is None), a symbolic
                            @see cl DimensionObject otherwise
        """
        if obj is None:
            return not isinstance(self._dim, int)
        # raw values (int, str) have no attribute _dim,
        # they are converted first
        other = DimensionObject._same_(obj)
        if isinstance(self._dim, int) and isinstance(other._dim, int):
            return self._dim > other._dim
        return DimensionObject(ShapeOperatorGreater(self, other))

412 

413 

class ShapeObject(BaseDimensionShape):
    """
    Handles mathematical operations around shapes.
    It stores a type (:epkg:`numpy` type),
    and a name to somehow have an idea of where
    the shape comes from in the :epkg:`ONNX` graph.
    The shape itself is defined by a list of
    @see cl DimensionObject or @see cl ShapeOperator
    or *None* if the shape is unknown. A dimension is an
    integer or a variable encoded as a string. This variable
    is a way to tell the dimension may vary.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from mlprodict.onnxrt.shape_object import ShapeObject

        sh1 = ShapeObject((1, 2), dtype=numpy.float32)
        sh2 = ShapeObject((45, 2), dtype=numpy.float32)
        mx = max(sh1, sh2)
        print(mx)

        sh1 = ShapeObject((1, 2), dtype=numpy.float32)
        sh2 = ShapeObject((None, 2), dtype=numpy.float32)
        print(sh2)
        mx = max(sh1, sh2)
        print(mx.to_string())

        sh1 = ShapeObject((1, 2), dtype=numpy.float32)
        sh2 = ShapeObject(('n', 2), dtype=numpy.float32)
        print(sh2)
        mx = max(sh1, sh2)
        print(mx.evaluate(n=4))
    """

    def __init__(self, shape, dtype=None, use_n1=False, name=None,
                 subtype=None):
        """
        @param      shape       tuple or `numpy.array`
        @param      dtype       dtype
        @param      use_n1      use `'n'` if the first dimension is unknown
        @param      name        optional, for debugging purposes
        @param      subtype     element type if this type is a list

        *shape* can also be a dictionary with a key ``'type'``
        describing a tensor, a map or a sequence, or *None*
        if the shape is unknown.
        """
        self.name = name
        self.subtype = subtype
        if isinstance(shape, numpy.ndarray):
            self._shape = [DimensionObject(s) for s in shape.shape]
            self._dtype = shape.dtype
        elif isinstance(shape, dict) and 'type' in shape:
            tshape = shape['type']
            if tshape['kind'] == 'tensor':
                if tshape['shape'] == ('?', ):
                    self._shape = None
                else:
                    self._shape = [DimensionObject(s) for s in tshape['shape']]
                self._dtype = tshape['elem']
            elif tshape['kind'] == 'map':
                self._shape = []
                self._dtype = 'map'
            elif tshape['kind'] == 'sequence':
                self._shape = []
                self._dtype = 'sequence'
            else:
                raise ValueError(  # pragma: no cover
                    "Wrong shape value {}".format(shape))
        elif isinstance(shape, (tuple, list)):
            self._shape = []
            for s in shape:
                self._shape.append(DimensionObject(s))
            self._dtype = dtype
        elif shape is None:
            # shape is unknown
            self._shape = None
            self._dtype = dtype
        else:
            raise TypeError(  # pragma: no cover
                "Unexpected type for shape: {}, shape={}".format(
                    type(shape), shape))

        def _dtype_again():
            # converts python types and strings into numpy types
            if self._dtype is None:
                raise TypeError(
                    "dtype cannot be None, shape type is {}\n{}".format(
                        type(shape), shape))
            if isinstance(self._dtype, numpy.dtype):
                # no need to go further
                return
            if self._dtype in (float, 'double', 'tensor(double)'):
                self._dtype = numpy.float64
            elif self._dtype in ('float32', 'float', 'tensor(float)'):
                self._dtype = numpy.float32
            elif self._dtype in (numpy.float16, 'float16', 'tensor(float16)'):
                self._dtype = numpy.float16
            elif self._dtype in ('int32', 'tensor(int32)'):
                self._dtype = numpy.int32
            elif self._dtype in (int, 'int', 'int64', 'tensor(int64)'):
                self._dtype = numpy.int64
            elif self._dtype in (str, 'str', numpy.str_, 'tensor(str)'):
                self._dtype = numpy.str_
            elif (hasattr(self._dtype, 'type') and
                    self._dtype.type is numpy.bytes_):
                # numpy.bytes_ is the name which remains in numpy 2.0,
                # the alias numpy.string_ was removed
                pass
            elif self._dtype in (bool, 'bool', numpy.bool_):
                self._dtype = numpy.bool_
            elif self._dtype in (object, numpy.object_):
                pass
            elif self._dtype in (numpy.int8, 'int8', ):
                self._dtype = numpy.int8
            elif self._dtype in (numpy.uint8, 'uint8', ):
                self._dtype = numpy.uint8
            elif self._dtype in (numpy.int16, 'int16', ):
                self._dtype = numpy.int16
            elif self._dtype in (numpy.uint16, 'uint16', ):
                self._dtype = numpy.uint16
            elif self._dtype in (numpy.uint32, 'uint32', ):
                self._dtype = numpy.uint32
            elif self._dtype in (numpy.uint64, 'uint64', ):
                self._dtype = numpy.uint64
            elif self._dtype in (numpy.complex64, 'complex64', ):
                self._dtype = numpy.complex64
            elif self._dtype in (numpy.complex128, 'complex128', ):
                self._dtype = numpy.complex128
            elif self._dtype == "tensor({'kind': 'tensor', 'elem': 'float', 'shape': })":
                self._dtype = numpy.float32
            elif self._dtype not in {
                    numpy.float32, numpy.float64, numpy.int32, numpy.int64,
                    numpy.str_, numpy.bool_, numpy.float16, None,
                    numpy.complex64, numpy.complex128,
                    'map', 'sequence'}:
                raise TypeError(  # pragma: no cover
                    "dtype has an unexpected value: '{}'.".format(self._dtype))
        try:
            _dtype_again()
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                "Unexpected error with %r of type %r, name=%r." % (
                    (self._dtype, type(self._dtype), name))) from e

        def _shape_again():
            # checks every dimension is a DimensionObject and
            # names the first one 'n' if requested and unknown
            if self._shape is not None:
                for i, a in enumerate(self._shape):
                    if not isinstance(a, DimensionObject):
                        raise TypeError(  # pragma: no cover
                            'Dimension {} has a wrong type {}'.format(
                                i, type(a)))
            if use_n1:
                sh = self._shape[0] if self._shape else None
                if isinstance(sh, DimensionObject) and sh._dim is None:
                    sh._dim = 'n'
            if self._shape is not None:
                for s in self._shape:
                    if isinstance(s, int):
                        raise TypeError(  # pragma: no cover
                            "Unexpected type int in shape %r." % self)
        _shape_again()

    def reshape(self, shape):
        """
        Creates a new shape, checks the number of elements is the same.

        @param      shape       new shape
        @return                 @see cl ShapeObject
        @raises     ValueError  if the number of elements of this shape
                                is an integer and differs from the new one
        """
        sh = ShapeObject(shape, self.dtype, getattr(self, '_dim', None),
                         self.name)
        p1 = self.product().evaluate()
        p2 = sh.product().evaluate()
        if isinstance(p1, int) and p1 != p2:
            # the message shows the original shape and the requested one
            raise ValueError("Shape {} cannot be reshaped into {} "
                             "(p1={}, p2={}).".format(self, shape, p1, p2))
        return sh

    def copy(self, dtype=None, name=None):
        """
        A copy not a deepcopy.

        @param      dtype   None or a value to rewrite the type.
        @param      name    overwrites the name
        @return             @see cl ShapeObject
        """
        new_dtype = self.dtype if dtype is None else dtype
        new_name = name or self.name
        if self._shape is None:
            # dtype and subtype are kept for unknown shapes too
            return ShapeObject(None, dtype=new_dtype, name=new_name,
                               subtype=self.subtype)
        return ShapeObject(self._shape.copy(), new_dtype,
                           name=new_name, subtype=self.subtype)

    def __getitem__(self, index):
        """
        Extracts a specific dimension.
        Returns *None* if the shape is unknown and the integer 1
        if *index* is an integer beyond the last dimension.
        """
        if self._shape is None:
            return None
        if isinstance(index, int) and index >= len(self._shape):
            return 1
        return self._shape[index]

    def __setitem__(self, index, value):
        """
        Changes a specific dimension. Does nothing if the shape
        is unknown, missing dimensions are filled with 1.
        """
        if self._shape is None:
            return
        while len(self._shape) <= index:
            self._shape.append(DimensionObject(1))
        self._shape[index] = value

    @property
    def shape(self):
        """
        Returns the stored shape as a tuple or *None* if unknown.
        """
        if self._shape is None:
            return None
        return tuple(self._shape)

    def __len__(self):
        """
        Returns the number of dimensions, 0 if the shape is unknown.
        """
        if self._shape is None:
            return 0
        return len(self._shape)

    @property
    def dtype(self):
        """
        Returns the stored *dtype*.
        """
        return self._dtype

    def reduce(self, axis=1, keepdims=False, dtype=None):
        """
        Reduces the matrix. Removes one dimension.

        @param      axis        axis
        @param      keepdims    keep dimensions, replaces the removed
                                dimension by 1
        @param      dtype       if not None, changes the type
        @return                 new dimension
        @raises     IndexError  if *axis* is out of bounds
        """
        if self._shape is None:
            if self.name is None:
                return self.copy(dtype=dtype)
            return self.copy(dtype=dtype, name="{}-RD".format(self.name))
        if axis is None:
            return ShapeObject((1, ), self._dtype if dtype is None else dtype,
                               name="{}-RDN".format(self.name))

        if isinstance(axis, ShapeObject):

            # NOTE(review): *a* is the evaluated shape of the axis,
            # a[0] is then a DimensionObject and not an integer,
            # this branch probably needs the value of the axis - confirm
            def drop_axis(shape, a):
                c = list(shape)
                del c[a[0]]
                return c

            return ShapeObjectFct(
                drop_axis, self, axis, name="DropAxis", dtype=self.dtype)

        if axis < 0:
            axis = len(self._shape) + axis
        if 0 <= axis < len(self._shape):
            cp = self._shape.copy()
            if keepdims:
                cp[axis] = DimensionObject(1)
            else:
                del cp[axis]
            return ShapeObject(cp, self._dtype if dtype is None else dtype,
                               name="{}-RD".format(self.name))
        raise IndexError("axis={} is wrong, shape is {}-tuple and equal to "
                         "{}".format(axis, len(self._shape), self))

    def __repr__(self):
        """
        usual
        """
        st = str(self.dtype)
        if "'" in st:
            # "<class 'numpy.float32'>" becomes "numpy.float32"
            st = st.split("'")[1]

        if self.shape is None:
            if self.name is None:
                return "ShapeObject(None, dtype={})".format(st)
            return "ShapeObject(None, dtype={}, name='{}')".format(st, self.name)

        st_shape = []
        for s in self.shape:
            if isinstance(getattr(s, "_dim", None), (int, str)):
                st_shape.append(str(s._dim))
            else:
                st_shape.append(repr(s))
        if len(st_shape) == 1:
            # produces a trailing comma for a tuple with one element
            st_shape.append('')
        st_shape = '({})'.format(", ".join(st_shape))
        if self.name is None:
            return "ShapeObject({}, dtype={})".format(st_shape, st)
        return "ShapeObject({}, dtype={}, name='{}')".format(
            st_shape, st, self.name)

    def __iter__(self):
        """
        Iterators over dimensions, empty if the shape is unknown.
        """
        if self._shape is not None:
            for d in self._shape:
                yield d

    def __gt__(self, a):
        """
        Compares shapes. Operator ``>``.
        An unknown shape is greater than a known one, a shape
        with more dimensions is greater, then dimensions
        are compared one by one.
        """
        if isinstance(a, tuple):
            a = ShapeObject(a, dtype=self._dtype)
        if self._shape is None and a._shape is None:
            return False
        if self._shape is None:
            return True
        if a._shape is None:
            return False
        if len(self) > len(a):
            return True
        if len(self) < len(a):
            return False
        for d1, d2 in zip(self, a):
            if d1 > d2:
                return True
            if d1 < d2:
                return False
        return False

    def __eq__(self, a):
        """
        Tests equality between two shapes.
        """
        if isinstance(a, tuple):
            a = ShapeObject(a, dtype=self._dtype)
        if self._shape is None and a._shape is None:
            return True
        if self._shape is None or a._shape is None:
            return False
        if len(self) != len(a):
            return False
        for d1, d2 in zip(self, a):
            if d1 == d2:
                continue
            return False
        return True

    def evaluate(self, **kwargs):
        """
        Evaluates the shape.

        @param      kwargs  value for the variables
        @return             @see cl ShapeObject
        """
        # NOTE(review): an unknown shape iterates over nothing and
        # becomes an empty shape here - confirm this is expected
        vs = [v.evaluate(**kwargs) for v in self]
        return ShapeObject(tuple(vs), self._dtype, name="{}-EV".format(self.name))

    def to_string(self, use_x=False):
        """
        Converts shapes into a string.
        The method iterates over the stored dimensions,
        the shape must be known.
        """
        shapes = [a.to_string(use_x=use_x) for a in self._shape]
        return '({})'.format(', '.join(shapes))

    def product(self):
        """
        Multiplies all the dimension.

        @return     @see cl DimensionObject
        """
        cl = self[0]
        for i in range(1, len(self)):
            cl = cl * self[i]
        return cl

    def append(self, dim):
        """
        Appends a dimension. Does nothing if the shape is unknown.
        """
        if self._shape is None:
            return
        self._shape.append(DimensionObject._same_(dim))

    def insert(self, dim, pos=0):
        """
        Inserts a dimension at position *pos*.
        Does nothing if the shape is unknown.
        """
        if self._shape is None:
            return
        self._shape.insert(pos, DimensionObject._same_(dim))

    def squeeze(self, axis):
        """
        Removes one dimension.
        """
        cp = self.copy(name='{}-SZ'.format(self.name))
        cp.drop_axis(axis)
        return cp

    def unsqueeze(self, axes):
        """
        Adds dimensions, a dimension 1 is inserted at every
        position in *axes*.

        @param      axes    positions of the new dimensions
        @return             @see cl ShapeObject
        """
        cp = self
        name = '{}-USZ'.format(self.name)
        # NOTE(review): positions are processed from the last one
        # to the first one, confirm this order is the expected one
        for ax in axes[::-1]:
            cp = cp.copy(name=name)
            # signature is insert(dim, pos): dimension 1 at position ax
            cp.insert(1, ax)
        return cp

    def transpose(self, perm):
        """
        Permutes the dimensions, dimension *i* of the result
        is dimension *perm[i]* of this shape.

        @param      perm    permutation
        @return             @see cl ShapeObject
        """
        if self.shape is None:
            return self.copy(name='{}-TR'.format(self.name))
        cp = ShapeObject([None for p in perm], dtype=self.dtype,
                         name="{}-TR".format(self.name))
        for i, p in enumerate(perm):
            if p >= len(self):
                # This should not happen. The dimension is left
                # unknown but remains a DimensionObject.
                cp._shape[i] = DimensionObject(None)
            else:
                cp._shape[i] = self._shape[p]
        return cp

    def drop_axis(self, axis):
        """
        Drops an axis or several if *axis* is a tuple or a list.
        The shape is modified inplace, nothing happens
        if the shape is unknown.
        """
        if self._shape is not None:
            if isinstance(axis, (tuple, list)):
                # highest indices first to keep the other ones valid
                for i in sorted(axis, reverse=True):
                    del self._shape[i]
            else:
                del self._shape[axis]

    def broadcast(self, a):
        """
        Computes the shape after a broadcast.

        @param      a       other shape
        @return             @see cl ShapeObject
        """
        if a is None:
            raise ValueError("a should not be None")  # pragma: no cover
        if a._shape is None:
            return a.copy()
        if self._shape is None:
            return self.copy()
        mx = max(len(self._shape), len(a._shape))
        res = []
        for i in range(mx):
            if i < len(self._shape):
                if i < len(a._shape):
                    res.append(ShapeOperatorMax(self[i], a[i]))
                else:
                    res.append(self[i])
            else:
                res.append(a[i])
        return ShapeObject(tuple(res), self.dtype, False,
                           name="broadcast-{}-{}".format(self.name, a.name))

    @staticmethod
    def _infer_merged_type(*args, use_dtype=True):
        """
        Returns the type shared by all arguments or
        ``numpy.float64`` if they mix the usual numerical types.

        @param      args        shapes or types
        @param      use_dtype   if True, uses attribute *dtype* of
                                every argument, otherwise the
                                arguments are the types
        @return                 type
        """
        if use_dtype:
            tys = {a.dtype for a in args}
        else:
            tys = set(args)
        if len(tys) == 1:
            return list(tys)[0]
        if tys & {numpy.float64, numpy.int64,
                  numpy.float32, numpy.int32,
                  numpy.float16}:
            return numpy.float64
        raise RuntimeError(  # pragma: no cover
            "Unable to infer types based on {} ({}).".format(
                tys, len(tys)))

    def concat_columns(self, axis, *shapes):
        """
        Concatenates columns from *shapes* to this one
        along one axis.

        @param      axis    axis
        @param      shapes  other shapes
        @return             @see cl ShapeObject, the shape is unknown
                            if one dimension along *axis* is unknown
        """
        args = [self] + list(shapes)
        dtype = self._infer_merged_type(*args)
        dim_axis = self[axis]
        if isinstance(dim_axis, int):
            # __getitem__ returns the integer 1 beyond the last dimension
            dim_axis = DimensionObject(dim_axis)
        if dim_axis is None:
            return ShapeObject(None, dtype=dtype)
        if isinstance(dim_axis, int):
            raise TypeError(  # pragma: no cover
                "Unexpected type for shape %r." % self)
        for a in shapes:
            if a[axis] is None:
                return ShapeObject(None, dtype=dtype)
            dim_axis = dim_axis + a[axis]
        a0 = args[0].copy(dtype=dtype)
        a0[axis] = dim_axis
        return a0

    @staticmethod
    def einsum_shape(equation, *inputs):
        """
        Computes :epkg:`einsum` shapes.
        Not the most efficient one as it creates variables
        of the given shapes.

        @param      equation    equation as bytes, it must
                                contain ``->``
        @param      inputs      input shapes
        @return                 @see cl ShapeObject, the first input
                                with an unknown shape if there is one
        """
        for inp in inputs:
            if inp.shape is None:
                return inp
        if b"->" not in equation:
            raise RuntimeError(  # pragma: no cover
                "Equation %r does not have '->'." % (equation, ))
        inp, out = [_.strip() for _ in equation.split(b"->")]
        inps = [_.strip() for _ in inp.split(b',')]
        if len(inputs) != len(inps):
            raise RuntimeError(  # pragma: no cover
                "Input mismatch between '{}' and {}.".format(equation, inps))
        shs = {}
        for a, b in zip(inps, inputs):
            if len(a) != len(b):
                raise RuntimeError(  # pragma: no cover
                    "Input mismatch '{}' (in '{}') and {}.".format(a, equation, b))
            for c, s in zip(a, b):
                if c not in shs:
                    shs[c] = s
                elif shs[c] != s:
                    raise RuntimeError(  # pragma: no cover
                        "Equation '{}'. Dimension mismatch '{}' != {}.".format(
                            equation, s, shs[c]))
        new_shape = [shs[i] for i in out]
        return ShapeObject(new_shape, dtype=ShapeObject._infer_merged_type(*inputs))

    @staticmethod
    def gather_shape(input, indices, axis):
        """
        Computes Gather shapes.

        @param      input       shape of the data
        @param      indices     shape of the indices
        @param      axis        axis
        @return                 @see cl ShapeObject, unknown if one
                                of the two shapes is unknown
        """
        # len() returns 0 for an unknown shape, never None:
        # the stored shape must be checked instead
        if input.shape is None or indices.shape is None:
            return ShapeObject(None, dtype=input._dtype)
        input_rank = len(input)

        if axis < 0:
            axis = input_rank + axis

        shape = [input[i] for i in range(axis)]
        shape.extend(indices)
        shape.extend(input[i] for i in range(axis + 1, input_rank))

        return ShapeObject(shape, dtype=input._dtype)

980 

981 

class ShapeObjectFct(ShapeObject):
    """
    Computes a shape depending on a user defined function.
    See @see cl Conv for an example.
    """

    def __init__(self, fct, *shapes, dtype=None, name=None):
        """
        @param      fct     function
        @param      shapes  shapes sent to fct
        @param      dtype   dtype
        @param      name    optional, for debugging purposes

        The shape itself is unknown until method *evaluate* is called.
        """
        ShapeObject.__init__(self, None, dtype=dtype, name=name)
        self._fct = fct
        self._shapes = shapes

    def evaluate(self, **kwargs):
        """
        Evaluates the shape: every input shape is evaluated,
        then the function is called on the results.

        @param      kwargs  value for the variables
        @return             result of the function, a list or a tuple
                            is converted into a @see cl ShapeObject
        """
        vs = [v.evaluate(**kwargs) for v in self._shapes]
        res = self._fct(*vs)
        if isinstance(res, (list, tuple)):
            # a function may return the dimensions as a plain sequence
            # which has no attribute name
            return ShapeObject(res, dtype=self.dtype, name=self.name)
        if self.name is not None:
            res.name = self.name
        return res

1004 for v in self._shapes: 

1005 d = v.evaluate(**kwargs) 

1006 vs.append(d) 

1007 res = self._fct(*vs) 

1008 if self.name is not None: 

1009 res.name = self.name 

1010 return res