Coverage for mlprodict/npy/onnx_variable.py: 94%

Shortcuts on this page

r, m, x: toggle the display of run, missing and excluded lines

j, k: jump to the next / previous highlighted chunk

0 (zero): jump to the top of the page

1 (one): jump to the first highlighted chunk

418 statements  

1""" 

2@file 

3@brief Intermediate class between :epkg:`numpy` and :epkg:`onnx`. 

4 

5.. versionadded:: 0.6 

6""" 

7import logging 

8import numpy 

9from onnx.helper import make_tensor 

10from ..onnx_tools.onnx2py_helper import guess_proto_dtype 

11from .xop_variable import Variable 

12from .xop import loadop, OnnxOperatorItem, OnnxOperatorTuple 

13from .xop_variable import guess_numpy_type 

14 

logger = logging.getLogger('xop')

# Aliases for the numpy scalar types used for booleans and strings.
# getattr falls back on the Python builtins when the installed numpy
# does not expose the attribute (same as catching AttributeError).
numpy_bool = getattr(numpy, 'bool_', bool)
numpy_str = getattr(numpy, 'str_', str)

26 

27 

class OnnxVar:
    """
    Variables used into :epkg:`onnx` computation.

    :param inputs: variable name or object
    :param op: :epkg:`ONNX` operator
    :param select_output: if multiple outputs are returned by
        ONNX operator *op*, it takes only the one specified by this
        argument
    :param dtype: specifies the type of the variable
        held by this class (*op* is None) in that case
    :param kwargs: additional arguments given to operator *op*

    .. versionadded:: 0.6
    """
    # With __array_ufunc__ = None, numpy returns NotImplemented for binary
    # operations involving an OnnxVar, so Python falls back on the reflected
    # operators defined below (__radd__, __rsub__, ...).
    __array_ufunc__ = None

    def __init__(self, *inputs, op=None, select_output=None,
                 dtype=None, **kwargs):
        logger.debug('OnnxVar(%d in, dtype=%r, op=%r, select_output=%r)',
                     len(inputs), dtype, op, select_output)
        self.inputs = inputs
        self.select_output = select_output
        self.onnx_op = op
        self.alg_ = None  # cache filled by to_algebra
        self.onnx_op_kwargs = kwargs
        if dtype is not None and (op is not None or len(inputs) != 1):
            raise RuntimeError(  # pragma: no cover
                "dtype can only be used if op is None and len(inputs) == 1.")
        for i, inp in enumerate(self.inputs):
            if isinstance(inp, type):
                raise TypeError(  # pragma: no cover
                    "Unexpected type for input %d - %r." % (i, inp))
            if not isinstance(inp, numpy.ndarray):
                continue
            # Rejects arrays whose first element is itself an array
            # or an OnnxVar (object arrays).
            if (inp.size > 0 and
                    isinstance(inp.ravel()[0], (numpy.ndarray, OnnxVar))):
                raise TypeError(  # pragma: no cover
                    "Unexpected type for input %d: %r, %r, "
                    "op=%r" % (i, type(inp), inp.ravel()[0], op))
        self.dtype = self._guess_dtype(dtype, from_init=True)

    def _guess_dtype(self, dtype, from_init=False):
        """
        Guesses dtype when not specified.

        :param dtype: returned as is if not None
        :param from_init: unused, kept for backward compatibility
        :return: the dtype shared by all inputs whose type can be guessed,
            None if any input is a string (a variable name) or if
            the inputs disagree
        """
        if dtype is not None:
            return dtype
        dtypes = []
        for i, inp in enumerate(self.inputs):
            if isinstance(inp, str):
                return None
            if isinstance(inp, numpy.ndarray):
                dtypes.append(inp.dtype)
            elif isinstance(inp, Variable):
                dtypes.append(inp.dtype)
            elif isinstance(inp, OnnxVar):
                dtypes.append(inp.dtype)
            elif isinstance(inp, MultiOnnxVar):
                dtypes.append(inp._guess_dtype(None))
            elif isinstance(inp, (numpy.float32, numpy.float64,
                                  numpy.int32, numpy.int64)):
                dtypes.append(inp.dtype)
            elif isinstance(inp, numpy_str):
                dtypes.append(numpy_str)
            elif isinstance(inp, numpy_bool):
                dtypes.append(numpy_bool)
            elif isinstance(inp, int):
                dtypes.append(numpy.int64)  # pragma: no cover
            elif isinstance(inp, float):
                dtypes.append(numpy.float64)
            elif hasattr(inp, 'fit'):
                # scikit-learn model
                continue
            elif hasattr(inp, '_guess_dtype'):
                dtypes.append(inp._guess_dtype(None))
            else:
                # A dedicated local name: rebinding the parameter *dtype*
                # would leak this guess into the following inputs.
                try:
                    guessed = guess_numpy_type(inp)
                except NotImplementedError as e:
                    raise TypeError(  # pragma: no cover
                        "Unexpected type for input %i type=%r." % (
                            i, type(inp))) from e
                dtypes.append(guessed)
        dtypes = [_ for _ in dtypes if _ is not None]
        unique = set(dtypes)
        if len(unique) != 1:
            return None
        return dtypes[0]

    def __repr__(self):
        "usual"
        args = []
        for inp in self.inputs:
            args.append(repr(inp))
        if self.onnx_op is not None:
            if isinstance(self.onnx_op, str):
                args.append("op=%r" % self.onnx_op)
            else:
                args.append("op=%s" % self.onnx_op.__name__)
        if self.select_output is not None:
            args.append("select_output=%r" % self.select_output)
        if self.dtype is not None and self.dtype != self._guess_dtype(None):
            args.append("dtype=%r" % self.dtype)
        for k, v in sorted(self.onnx_op_kwargs.items()):
            args.append("%s=%r" % (k, v))
        res = "%s(%s)" % (self.__class__.__name__, ", ".join(args))
        return res

    def set_onnx_name(self, name_type):
        """
        Stores *name_type* in attribute *onnx_input_type_*.
        NOTE(review): presumably used elsewhere to force the name and type
        of this variable when the ONNX graph is built - confirm with callers.

        :param name_type: a tuple *(name, type)*
        """
        self.onnx_input_type_ = name_type

    def to_algebra(self, op_version=None):
        """
        Converts the variable into an operator.
        The result is cached in *alg_* and returned by later calls.

        :param op_version: opset version given to every converted
            input and to the operator
        :return: a @see cl Variable if *op* is None, the result of
            *op* (or one of its outputs if *select_output* is set) otherwise
        """
        if self.alg_ is not None:
            return self.alg_

        if self.onnx_op is None:
            logger.debug('OnnxVar.to_algebra:1(op_version=%r)', op_version)
            if len(self.inputs) != 1:
                raise RuntimeError(  # pragma: no cover
                    "Unexpected number of inputs, 1 expected, "
                    "got {} instead.".format(self.inputs))
            if self.dtype is None or hasattr(self.inputs[0], 'onnx_name'):
                self.alg_ = Variable.from_skl2onnx(self.inputs[0])
            elif isinstance(self.inputs[0], Variable):
                self.alg_ = self.inputs[0]
            else:
                self.alg_ = Variable(self.inputs[0], self.dtype)
            return self.alg_

        logger.debug('OnnxVar.to_algebra:2(op_version=%r) - onnx_op=%r',
                     op_version, self.onnx_op)
        if isinstance(self.onnx_op, str):
            # custom operator such as 'filter', expanded into ONNX operators
            var = self._custom_op(*self.inputs, op_version=op_version,
                                  **self.onnx_op_kwargs)
            alg = var.to_algebra(op_version=op_version)
            self.alg_ = alg
            return alg

        new_inputs = []
        for inp in self.inputs:
            if hasattr(inp, 'fit'):
                # scikit-learn model
                new_inputs.append(inp)
            elif isinstance(inp, (
                    int, float, str, numpy.ndarray, numpy.int32,
                    numpy.int64, numpy.float32, numpy.float64,
                    numpy_bool, numpy_str, numpy.int8, numpy.uint8,
                    numpy.int16, numpy.uint16, numpy.uint32,
                    numpy.uint64)):
                # Only arrays can hold nested objects; Python scalars
                # have no attribute size or method ravel.
                if (isinstance(inp, numpy.ndarray) and inp.size > 0 and
                        isinstance(inp.ravel()[0],
                                   (numpy.ndarray, OnnxVar))):
                    raise TypeError(  # pragma: no cover
                        "Unexpected type for an input %r, %r."
                        "" % (type(inp), inp.ravel()[0]))
                new_inputs.append(inp)
            else:
                new_inputs.append(
                    inp.to_algebra(op_version=op_version))

        res = self.onnx_op(*new_inputs, op_version=op_version,
                           **self.onnx_op_kwargs)
        if self.select_output is None:
            self.alg_ = res
        else:
            self.alg_ = res[self.select_output]
        return self.alg_

    def _custom_op(self, *args, op_version=None, runtime=None, **kwargs):
        """
        This could be handled before a call to this method
        but this method can change the conversion of an non-existing
        operator depending on the given opset.
        Only custom operator ``'filter'`` is supported.
        """
        if self.onnx_op == 'filter':
            return self._custom_op_filter(*args, op_version=op_version,
                                          runtime=runtime, **kwargs)
        raise NotImplementedError(  # pragma: no cover
            "Unexpected custom operator %r." % self.onnx_op)

    def _custom_op_filter(self, *args, op_version=None, runtime=None, **kwargs):
        """
        Expands custom operator ``'filter'`` (``mat[index]`` with
        *index* an @see cl OnnxVar) into standard operators:
        the mask is cast into int64 and squeezed, ReduceSum counts
        the selected rows, TopK returns their positions and Gather
        extracts them from *mat*.

        :param args: exactly two inputs, *mat* and *index*
        :return: @see cl OnnxVar
        """
        OnnxSqueeze, OnnxTopK, OnnxGather, OnnxReduceSum = loadop(
            'Squeeze', 'TopK', 'Gather', 'ReduceSum')
        if len(args) != 2:
            raise RuntimeError(  # pragma: no cover
                "Custom op 'filter' expects two inputs not %r." % len(args))
        if len(kwargs) != 0:
            raise RuntimeError(  # pragma: no cover
                "Custom op 'filter' expects no arguments but got %r." % kwargs)
        mat, index = args
        cast = OnnxVar(index.astype(numpy.int64), op=OnnxSqueeze)
        n1 = OnnxVar(cast, op=OnnxReduceSum, keepdims=1)
        indices = OnnxVar(cast, n1, op=OnnxTopK, select_output=1)
        return OnnxVar(mat, indices, op=OnnxGather)

    @property
    def T(self):
        "Transpose."
        OnnxTranspose = loadop('Transpose')
        return OnnxVar(self, op=OnnxTranspose)

    def astype(self, dtype):
        "Cast"
        OnnxCast = loadop('Cast')
        return OnnxVar(self, op=OnnxCast, to=guess_proto_dtype(dtype))

    @property
    def shape(self):
        "Shape"
        OnnxShape = loadop('Shape')
        return OnnxVar(self, op=OnnxShape)

    @property
    def size(self):
        "Size"
        OnnxSize = loadop('Size')
        return OnnxVar(self, op=OnnxSize)

    def reshape(self, shape):
        "Reshape, a tuple or a list is converted into an int64 array."
        OnnxReshape = loadop('Reshape')
        if isinstance(shape, (tuple, list)):
            shape = numpy.array(shape, dtype=numpy.int64)
        return OnnxVar(self, shape, op=OnnxReshape)

    def _make_array(self, y):
        """
        Converts *y* into an array if not.
        Arrays and @see cl OnnxVar are returned unchanged, numpy scalars
        become a one-element array of the same dtype, Python floats
        are cast into float32 and Python ints into int64.
        Any other object is returned unchanged.
        """
        if isinstance(y, (numpy.ndarray, OnnxVar)):
            return y
        if hasattr(y, 'dtype'):
            return numpy.full((1, ), y, dtype=y.dtype)
        if isinstance(y, str):
            return numpy.array([y])
        if isinstance(y, float):
            return numpy.array([y], dtype=numpy.float32)
        if isinstance(y, int):
            return numpy.array([y], dtype=numpy.int64)
        return y

    def __add__(self, y):
        "Addition."
        y = self._make_array(y)
        OnnxAdd = loadop('Add')
        return OnnxVar(self, y, op=OnnxAdd)

    def __radd__(self, y):
        "Right Addition."
        y = self._make_array(y)
        OnnxIdentity, OnnxAdd = loadop('Identity', 'Add')
        return OnnxVar(OnnxVar(y, op=OnnxIdentity), self, op=OnnxAdd)

    def __sub__(self, y):
        "Subtraction."
        y = self._make_array(y)
        OnnxSub = loadop('Sub')
        return OnnxVar(self, y, op=OnnxSub)

    def __rsub__(self, y):
        "Right subtraction."
        y = self._make_array(y)
        OnnxIdentity, OnnxSub = loadop('Identity', 'Sub')
        return OnnxVar(OnnxVar(y, op=OnnxIdentity), self, op=OnnxSub)

    def __mul__(self, y):
        "Multiplication."
        y = self._make_array(y)
        OnnxMul = loadop('Mul')
        return OnnxVar(self, y, op=OnnxMul)

    def __rmul__(self, y):
        "Right multiplication."
        y = self._make_array(y)
        OnnxIdentity = loadop('Identity')
        return OnnxVar(y, op=OnnxIdentity) * self

    def __pow__(self, y):
        "Power."
        y = self._make_array(y)
        OnnxPow = loadop('Pow')
        return OnnxVar(self, y, op=OnnxPow)

    def __rpow__(self, y):
        "Right power."
        y = self._make_array(y)
        OnnxIdentity, OnnxPow = loadop('Identity', 'Pow')
        return OnnxVar(OnnxVar(y, op=OnnxIdentity), self, op=OnnxPow)

    def __mod__(self, y):
        "Modulo."
        y = self._make_array(y)
        OnnxMod = loadop('Mod')
        return OnnxVar(self, y, op=OnnxMod)

    def __rmod__(self, y):
        "Right modulo."
        y = self._make_array(y)
        OnnxIdentity, OnnxMod = loadop('Identity', 'Mod')
        return OnnxVar(OnnxVar(y, op=OnnxIdentity), self, op=OnnxMod)

    def __matmul__(self, y):
        "Matrix multiplication."
        y = self._make_array(y)
        OnnxMatMul = loadop('MatMul')
        return OnnxVar(self, y, op=OnnxMatMul)

    def __rmatmul__(self, y):
        "Right matrix multiplication."
        y = self._make_array(y)
        OnnxIdentity, OnnxMatMul = loadop('Identity', 'MatMul')
        return OnnxVar(OnnxVar(y, op=OnnxIdentity), self, op=OnnxMatMul)

    def __truediv__(self, y):
        "Division, no difference between `/` and `//`."
        y = self._make_array(y)
        OnnxDiv = loadop('Div')
        return OnnxVar(self, y, op=OnnxDiv)

    def __rtruediv__(self, y):
        "Division, no difference between `/` and `//`."
        y = self._make_array(y)
        OnnxIdentity, OnnxDiv = loadop('Identity', 'Div')
        return OnnxVar(OnnxVar(y, op=OnnxIdentity), self, op=OnnxDiv)

    def __floordiv__(self, y):
        "Division, no difference between `/` and `//`."
        y = self._make_array(y)
        OnnxDiv = loadop('Div')
        return OnnxVar(self, y, op=OnnxDiv)

    def __rfloordiv__(self, y):
        "Right division, no difference between `/` and `//`."
        y = self._make_array(y)
        OnnxIdentity, OnnxDiv = loadop('Identity', 'Div')
        return OnnxVar(OnnxVar(y, op=OnnxIdentity), self, op=OnnxDiv)

    def __eq__(self, y):
        "Equality."
        y = self._make_array(y)
        OnnxEqual = loadop('Equal')
        return OnnxVar(self, y, op=OnnxEqual)

    def __ne__(self, y):
        "Difference."
        y = self._make_array(y)
        OnnxEqual, OnnxNot = loadop('Equal', 'Not')
        return OnnxVar(OnnxVar(self, y, op=OnnxEqual), op=OnnxNot)

    def __ge__(self, y):
        "Greater or Equal."
        y = self._make_array(y)
        OnnxGreaterOrEqual = loadop('GreaterOrEqual')
        return OnnxVar(self, y, op=OnnxGreaterOrEqual)

    def __gt__(self, y):
        "Greater."
        y = self._make_array(y)
        OnnxGreater = loadop('Greater')
        return OnnxVar(self, y, op=OnnxGreater)

    def __invert__(self):
        "not."
        OnnxNot = loadop('Not')
        return OnnxVar(self, op=OnnxNot)

    def __le__(self, y):
        "Less or Equal."
        y = self._make_array(y)
        OnnxLessOrEqual = loadop('LessOrEqual')
        return OnnxVar(self, y, op=OnnxLessOrEqual)

    def __lt__(self, y):
        "Less."
        y = self._make_array(y)
        OnnxLess = loadop('Less')
        return OnnxVar(self, y, op=OnnxLess)

    def __and__(self, y):
        "And."
        y = self._make_array(y)
        OnnxAnd = loadop('And')
        return OnnxVar(self, y, op=OnnxAnd)

    def __or__(self, y):
        "Or."
        y = self._make_array(y)
        OnnxOr = loadop('Or')
        return OnnxVar(self, y, op=OnnxOr)

    def not_(self):
        "Not."
        OnnxNot = loadop('Not')
        return OnnxVar(self, op=OnnxNot)

    def __neg__(self):
        "Neg."
        OnnxNeg = loadop('Neg')
        return OnnxVar(self, op=OnnxNeg)

    def __getitem__(self, index):
        """
        Deals with multiple scenarios.

        * *index* is an integer or a slice, a tuple of integers and slices,
          example: `[0, 1]`, `[:5, :6]`, `[::2]` (**scenario 1**)
        * *index* is an *ONNX* object (more precisely an instance of
          @see cl OnnxVar), then the method assumes it is an array of
          boolean to select a subset of the tensor along the first axis,
          example: `mat[mat == 0]` (**scenario 2**)

        Like :epkg:`numpy`, a slice with a negative step and no bound
        (`[::-1]`) goes from the last element down to the first one.
        An index made of full slices only (`[:]`) returns a copy
        (Identity node).
        """
        if isinstance(index, OnnxVar):
            # scenario 2
            return OnnxVar(self, index, op='filter')

        if isinstance(index, int):
            # Use Gather instead.
            OnnxGather = loadop('Gather')
            return OnnxVar(
                self, numpy.array(index, dtype=numpy.int64),
                axis=0, op=OnnxGather)

        if not isinstance(index, tuple):
            index = (index, )

        # Exactly one integer, every other index being a full slice `:`?
        # Then one Gather along that axis is enough: x[:, k] is
        # Gather(x, k, axis=1).
        gather_axis = None
        gather_index = None
        for i, a in enumerate(index):
            if isinstance(a, int):
                if gather_axis is None:
                    gather_axis, gather_index = i, a
                    continue
                # second integer, falls back on Slice + Squeeze
                gather_axis = None
                break
            if (isinstance(a, slice) and a.start is None and
                    a.stop is None and a.step is None):
                continue
            gather_axis = None
            break
        if gather_axis is not None:
            OnnxGather = loadop('Gather')
            return OnnxVar(
                self, numpy.array(gather_index, dtype=numpy.int64),
                axis=gather_axis, op=OnnxGather)

        # scenario 1
        # ONNX Slice clamps out-of-range bounds, INT64_MAX / INT64_MIN
        # mean "up to the end" / "down to the first element".
        int64_max = numpy.iinfo(numpy.int64).max
        int64_min = numpy.iinfo(numpy.int64).min
        starts = []
        ends = []
        axes = []
        steps = []
        axis_squeeze = []
        needs_shape = []
        for i, ind in enumerate(index):
            if isinstance(ind, int):
                starts.append(ind)
                # ind + 1 is 0 when ind == -1 and would select nothing.
                ends.append(int64_max if ind == -1 else ind + 1)
                axes.append(i)
                steps.append(1)
                axis_squeeze.append(i)
                continue
            if isinstance(ind, slice):
                if ind.start is None and ind.stop is None and ind.step is None:
                    continue
                step = 1 if ind.step is None else ind.step
                backward = isinstance(step, int) and step < 0
                if ind.start is not None:
                    start = ind.start
                else:
                    start = -1 if backward else 0
                if ind.stop is not None:
                    end = ind.stop
                elif backward:
                    end = int64_min
                else:
                    # (None, i): the end is shape[i], only known at runtime
                    end = (None, i)
                starts.append(start)
                ends.append(end)
                axes.append(i)
                steps.append(step)
                if isinstance(end, tuple):
                    needs_shape.append(len(ends) - 1)
                elif isinstance(end, OnnxVar):
                    needs_shape.append(end)
                continue
            raise NotImplementedError(  # pragma: no cover
                "Not implemented for type %r." % type(ind))

        if not axes:
            # only full slices (`x[:]`, `x[:, :]`), nothing to slice
            OnnxIdentity = loadop('Identity')
            return OnnxVar(self, op=OnnxIdentity)

        if max(steps) == min(steps) == 1:
            steps = None
        else:
            steps = numpy.array(steps, dtype=numpy.int64)

        starts = numpy.array(starts, dtype=numpy.int64)
        axes = numpy.array(axes, dtype=numpy.int64)

        OnnxGather, OnnxSlice, OnnxSqueeze, OnnxConcat = loadop(
            'Gather', 'Slice', 'Squeeze', 'Concat')
        if len(needs_shape) > 0:
            # ends mixes constants and values known at runtime,
            # they are concatenated into a single 1D tensor
            shape = self.shape
            conc = []
            for e in ends:
                if isinstance(e, tuple):
                    conc.append(
                        OnnxVar(shape, numpy.array([e[1]], numpy.int64),
                                op=OnnxGather))
                elif isinstance(e, OnnxVar):
                    conc.append(
                        e.reshape(numpy.array([-1], dtype=numpy.int64)))
                else:
                    conc.append(numpy.array([e], dtype=numpy.int64))
            if len(conc) > 1:
                ends = OnnxVar(*conc, op=OnnxConcat, axis=0)
            else:
                ends = conc[0]
        else:
            ends = numpy.array(ends, dtype=numpy.int64)

        if steps is None:
            sliced = OnnxVar(self, starts, ends, axes, op=OnnxSlice)
        else:
            sliced = OnnxVar(self, starts, ends, axes, steps, op=OnnxSlice)
        if len(axis_squeeze) > 0:
            # integer indices remove their dimension like numpy does
            return OnnxVar(
                sliced, numpy.array(axis_squeeze, dtype=numpy.int64),
                op=OnnxSqueeze)
        return sliced

    def __setitem__(self, index, value):
        """
        Only supports vectors (1D tensor).

        * *index* is an integer or a slice, a tuple of integers and slices,
          example: `[0]`, `[:5]`, `[::2]` (**scenario 1**)
        * *index* is an *ONNX* object (more precisely an instance of
          @see cl OnnxVar), then the method assumes it is an array of
          boolean to select a subset of the tensor along the first axis,
          example: `mat[mat == 0]` (**scenario 2**)
        This processing is applied before the operator it contains.
        A copy should be made (Identity node or copy method).
        """
        OnnxIdentity = loadop('Identity')
        if self.onnx_op is not None and self.onnx_op is not OnnxIdentity:
            raise RuntimeError(  # pragma: no cover
                "A copy should be made before setting new values on a matrix. "
                "Method copy() would do that.")

        if isinstance(index, OnnxVar):
            # scenario 2, example: cp[x < 0] = -1
            return self._setitem2i_(index, value)
        elif not isinstance(index, tuple):
            index = (index, )

        for i in index:
            if isinstance(i, OnnxVar):
                raise NotImplementedError(  # pragma: no cover
                    "Unable to handle case such as cp[0, x < 0] = -1.")

        # scenario 1
        if len(index) == 1:
            return self._setitem1i_(index[0], value)
        raise NotImplementedError(  # pragma: no cover
            "Indices in %d dimensions are not implemented yet." % len(index))

    def _setitem1i_(self, index, value):
        """
        Assigns *value* to a slice or a single position of a 1D tensor
        with ScatterElements. The result replaces *self.inputs*.
        """
        sl = None
        if isinstance(index, slice):
            start = 0 if index.start is None else index.start
            stop = index.stop
            step = index.step
            sl = [start, stop, step]
        elif isinstance(index, int):
            sl = [index, index + 1, 1]
        else:
            raise NotImplementedError(  # pragma: no cover
                "Unable to assign new values due to unexpected type %r."
                "" % type(index))

        if sl[1] is None and isinstance(value, numpy.ndarray):
            sl[1] = sl[0] + value.size
        OnnxConstantOfShape, OnnxScatterElements = loadop(
            'ConstantOfShape', 'ScatterElements')
        if sl[1] is None:
            # unknown end: builds a tensor filled with value and copies
            # back the first sl[0] elements of the input
            if sl[2] is not None and sl[2] != 1:
                raise NotImplementedError(  # pragma: no cover
                    "If the length is not known, step must be 1 not %d." % sl[2])
            value = make_tensor(
                "value", guess_proto_dtype(value.dtype), (1, ), [value])  # pylint: disable=E1101
            inp = self.inputs[0]
            if not isinstance(inp, OnnxVar):
                raise RuntimeError(  # pragma: no cover
                    "Input must be an instance of OnnxVar not %r." % type(inp))
            cst = OnnxVar(inp.shape, op=OnnxConstantOfShape, value=value)
            ext = inp[:sl[0]]
            indices = numpy.arange(0, sl[0]).astype(numpy.int64)
            add_step = OnnxVar(cst, indices, ext,
                               op=OnnxScatterElements, axis=0)
        else:
            indices = numpy.arange(sl[0], sl[1], sl[2]).astype(numpy.int64)
            if isinstance(value, numpy.ndarray):
                values = value
            else:
                values = numpy.full(indices.shape, value)
            add_step = OnnxVar(self.inputs[0], indices, values,
                               op=OnnxScatterElements, axis=0)

        self.inputs = [add_step]
        return self

    def _setitem2i_(self, index, value):
        """
        Assigns *value* where the boolean tensor *index* is true
        (Where operator). The result replaces *self.inputs*.
        """
        OnnxWhere = loadop('Where')
        add_step = OnnxVar(index, value, self.inputs[0], op=OnnxWhere)
        self.inputs = [add_step]
        return self

    def copy(self):
        """
        Returns a copy of self (use of Identity node).
        """
        OnnxIdentity = loadop('Identity')
        return OnnxVar(self, op=OnnxIdentity)

    def flatten(self, axis=0):
        """
        Flattens a matrix (see :epkg:`numpy:ndarray:flatten`).

        :param axis: only flatten from axis to the end,
            with *axis=0*, the leading dimension of size 1 produced
            by operator Flatten is squeezed
        :return: @see cl OnnxVar.
        """
        OnnxFlatten, OnnxSqueeze = loadop('Flatten', 'Squeeze')
        fl = OnnxVar(self, op=OnnxFlatten, axis=axis)
        if axis == 0:
            return OnnxVar(fl, numpy.array([0], dtype=numpy.int64),
                           op=OnnxSqueeze)
        return fl

647 

648 

class MultiOnnxVar:
    """
    Class used to return multiple @see cl OnnxVar
    at the same time.

    It wraps one @see cl OnnxVar (attribute *onxvar*) and exposes
    its *inputs*, *onnx_op* and *onnx_op_kwargs*.
    """

    def __init__(self, *inputs, op=None, dtype=None, **kwargs):
        "constructor"
        logger.debug('MultiOnnxVar(%d in, dtype=%r, op=%r)',
                     len(inputs), dtype, op)
        # NOTE(review): dtype is only logged, it is never forwarded
        # to the wrapped OnnxVar.
        self.onxvar = OnnxVar(*inputs, op=op, dtype=None, **kwargs)
        self.alg_ = None  # cache filled by to_algebra

    def _guess_dtype(self, dtype):
        "Guesses dtype when not specified."
        return self.onxvar._guess_dtype(dtype)

    @property
    def inputs(self):
        "Returns `self.onxvar.inputs`."
        return self.onxvar.inputs

    @property
    def onnx_op(self):
        "Returns `self.onxvar.onnx_op`."
        return self.onxvar.onnx_op

    @property
    def onnx_op_kwargs(self):
        "Returns `self.onxvar.onnx_op_kwargs`."
        return self.onxvar.onnx_op_kwargs

    def to_algebra(self, op_version=None):
        """
        Converts the variable into an operator.
        The result is cached in *alg_* and returned by later calls.

        :param op_version: opset version given to every converted
            input and to the operator
        :return: an *OnnxOperatorTuple* wrapping either the converted
            inputs (no operator) or the result of the operator
        """
        if self.alg_ is None:
            logger.debug('MultiOnnxVar.to_algebra(op_version=%r)',
                         op_version)
            new_inputs = []
            for inp in self.inputs:
                if isinstance(inp, (
                        int, float, str, numpy.ndarray, numpy.int32,
                        numpy.int64, numpy.float32, numpy.float64,
                        numpy_bool, numpy_str, numpy.int8, numpy.uint8,
                        numpy.int16, numpy.uint16, numpy.uint32,
                        numpy.uint64)):
                    new_inputs.append(inp)
                elif hasattr(inp, 'fit'):
                    # scikit-learn models
                    new_inputs.append(inp)
                else:
                    new_inputs.append(
                        inp.to_algebra(op_version=op_version))

            if self.onnx_op is None:
                if len(new_inputs) == 1:
                    logger.debug('MultiOnnxVar.to_algebra:1:new_inputs[0]=%r',
                                 new_inputs[0])
                    self.alg_ = OnnxOperatorTuple(new_inputs[0])
                else:
                    logger.debug('MultiOnnxVar.to_algebra:2:new_inputs=%r',
                                 new_inputs)
                    self.alg_ = OnnxOperatorTuple(
                        new_inputs[0], *(new_inputs[1:]))
            else:
                # onnx_op is a class: its own __name__ identifies the
                # operator, __class__.__name__ would be its metaclass.
                op_name = getattr(self.onnx_op, '__name__',
                                  type(self.onnx_op).__name__)
                logger.debug('MultiOnnxVar.to_algebra:%s:new_inputs=%r',
                             op_name, new_inputs)
                res = self.onnx_op(  # pylint: disable=E1102
                    *new_inputs, op_version=op_version, **self.onnx_op_kwargs)
                self.alg_ = OnnxOperatorTuple(res)
        return self.alg_

    def __getitem__(self, index):
        """
        Returns the ith elements.
        """
        return OnnxVar(self, index=index, op=OnnxOperatorItem)