Coverage for mlprodict/onnx_tools/onnx2py_helper.py: 94%

420 statements  

1""" 

2@file 

3@brief Functions which convert :epkg:`ONNX` objects into

4readable :epkg:`python` objects.

5""" 

6import pprint 

7import warnings 

8import numpy 

9from scipy.sparse import coo_matrix 

10from onnx.defs import get_schema, get_function_ops, onnx_opset_version 

11from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE, TENSOR_TYPE_TO_NP_TYPE 

12from onnx import TensorProto, ValueInfoProto 

13from onnx.helper import make_tensor_type_proto 

14from onnx.numpy_helper import to_array, from_array as onnx_from_array 

15 

16 

17def to_bytes(val): 

18 """ 

19 Converts an array into protobuf and then into bytes. 

20 

21 :param val: array 

22 :return: bytes 

23 

24 .. exref:: 

25 :title: Converts an array into bytes (serialization) 

26 

27 Useful to serialize. 

28 

29 .. runpython:: 

30 :showcode: 

31 :warningout: DeprecationWarning 

32 

33 import numpy 

34 from mlprodict.onnx_tools.onnx2py_helper import to_bytes 

35 

36 data = numpy.array([[0, 1], [2, 3], [4, 5]], dtype=numpy.float32) 

37 pb = to_bytes(data) 

38 print(len(pb), data.size * data.itemsize, pb[:10]) 

39 """ 

40 if isinstance(val, numpy.ndarray): 

41 pb = from_array(val) 

42 else: 

43 pb = val # pragma: no cover 

44 return pb.SerializeToString() 

45 

46 

47def from_array(value, name=None): 

48 """ 

49 Converts an array into an ONNX tensor. 

50 

51 :param value: numpy array 

52 :return: ONNX tensor 

53 """ 

54 if isinstance(value, numpy.ndarray): 

55 try: 

56 pb = onnx_from_array(value, name=name) 

57 except NotImplementedError as e: # pragma: no cover 

58 if value.dtype == numpy.dtype('O'): 

59 pb = TensorProto() 

60 pb.data_type = TensorProto.STRING # pylint: disable=E1101 

61 if name is not None: 

62 pb.name = name 

63 pb.dims.extend(value.shape) # pylint: disable=E1101 

64 pb.string_data.extend( # pylint: disable=E1101 

65 list(map(lambda o: str(o).encode('utf-8'), value.ravel()))) 

66 else: 

67 raise NotImplementedError( 

68 "Unable to convert type %r (dtype=%r) into an ONNX tensor " 

69 "due to %r." % (type(value), value.dtype, e)) from e 

70 return pb 

71 if isinstance(value, TensorProto): # pragma: no cover 

72 return value 

73 raise NotImplementedError( # pragma: no cover 

74 "Unable to convert type %r into an ONNX tensor." % type(value)) 

75 

76 
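# A minimal usage sketch for from_array (illustrative only; it assumes numpy and
# onnx are installed as imported above, and the `_example_from_array` name is an
# editorial placeholder, not part of the original module).
def _example_from_array():
    data = numpy.array([[0, 1], [2, 3]], dtype=numpy.float32)
    tensor = from_array(data, name='X')
    # tensor is a TensorProto with name='X', data_type=TensorProto.FLOAT, dims=[2, 2]
    return tensor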

77def from_bytes(b): 

78 """ 

79 Retrieves an array from bytes by parsing them as a protobuf tensor. 

80 

81 :param b: bytes 

82 :return: array 

83 

84 .. exref:: 

85 :title: Converts bytes into an array (serialization) 

86 

87 Useful to deserialize. 

88 

89 .. runpython:: 

90 :showcode: 

91 :warningout: DeprecationWarning 

92 

93 import numpy 

94 from mlprodict.onnx_tools.onnx2py_helper import to_bytes, from_bytes 

95 

96 data = numpy.array([[0, 1], [2, 3], [4, 5]], dtype=numpy.float32) 

97 pb = to_bytes(data) 

98 data2 = from_bytes(pb) 

99 print(data2) 

100 """ 

101 if isinstance(b, bytes): 

102 pb = TensorProto() 

103 pb.ParseFromString(b) 

104 else: 

105 pb = b # pragma: no cover 

106 return to_array(pb) 

107 

108 

109def _numpy_array(data, dtype=None, copy=True): 

110 """ 

111 Single function to create an array. 

112 

113 @param data data 

114 @param dtype dtype 

115 @param copy copy 

116 @return numpy array 

117 """ 

118 if isinstance(data, numpy.ndarray): 

119 res = data 

120 else: 

121 res = numpy.array(data, dtype=dtype, copy=copy) 

122 return res 

123 

124 

125def _sparse_array(shape, data, indices, dtype=None, copy=True): 

126 """ 

127 Single function to create a sparse array 

128 (:epkg:`coo_matrix`). 

129 

130 @param shape shape 

131 @param data data 

132 @param indices indices 

133 @param dtype dtype 

134 @param copy copy 

135 @return :epkg:`coo_matrix` 

136 """ 

137 if len(shape) != 2: 

138 raise ValueError( # pragma: no cover 

139 "Only matrices are allowed or sparse matrices " 

140 "but shape is {}.".format(shape)) 

141 rows = numpy.array([i // shape[1] for i in indices]) 

142 cols = numpy.array([i % shape[1] for i in indices]) 

143 if isinstance(data, numpy.ndarray): 

144 res = coo_matrix((data, (rows, cols)), dtype=dtype) 

145 else: 

146 res = coo_matrix( # pragma: no cover 

147 (numpy.array(data, dtype=dtype, copy=copy), 

148 (rows, cols)), dtype=dtype) 

149 return res 

150 

151 
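# A short sketch of how _sparse_array interprets flat indices (illustrative only;
# the helper below is a hypothetical example, not part of the original module).
def _example_sparse_array():
    # flat indices 1 and 5 in a (2, 3) matrix map to cells (0, 1) and (1, 2)
    values = numpy.array([10.0, 20.0], dtype=numpy.float32)
    mat = _sparse_array((2, 3), values, [1, 5], dtype=numpy.float32)
    return mat.todense()  # [[0, 10, 0], [0, 0, 20]]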

152def guess_numpy_type_from_string(name): 

153 """ 

154 Converts a string (such as `'float'`) into a 

155 numpy dtype. 

156 """ 

157 if name in ('float', 'float32'): 

158 return numpy.float32 

159 if name in ('double', 'float64'): 

160 return numpy.float64 

161 if name == 'float16': 

162 return numpy.float16 

163 if name == 'int64': 

164 return numpy.int64 

165 if name == 'int8': 

166 return numpy.int8 

167 if name == 'uint8': 

168 return numpy.uint8 

169 if name == 'int32': 

170 return numpy.int32 

171 if name == 'int16': 

172 return numpy.int16 

173 if name == 'bool': 

174 return numpy.bool_ 

175 if name == 'str': 

176 return numpy.str_ 

177 raise ValueError( # pragma: no cover 

178 "Unable to guess numpy dtype from %r." % name) 

179 

180 

181def guess_numpy_type_from_dtype(dt): 

182 """ 

183 Converts a dtype (such as `numpy.dtype('float32')`) into a

184 numpy type (such as `numpy.float32`).

185 """ 

186 if dt in {numpy.int8, numpy.uint8, numpy.float16, numpy.float32, 

187 numpy.float64, numpy.int32, numpy.int64, numpy.int16, 

188 numpy.uint16, numpy.uint32, numpy.bool_, numpy.str_, 

189 numpy.uint64, bool, str, }: 

190 return dt 

191 if dt == numpy.dtype('float32'): 

192 return numpy.float32 

193 if dt == numpy.dtype('float64'): 

194 return numpy.float64 

195 if dt == numpy.dtype('int64'): 

196 return numpy.int64 

197 if dt == numpy.dtype('int8'): 

198 return numpy.int8 

199 if dt == numpy.dtype('uint8'): 

200 return numpy.uint8 

201 raise ValueError( # pragma: no cover 

202 "Unable to guess numpy dtype from %r." % dt) 

203 

204 
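# A minimal sketch of both guess functions above (illustrative only; it assumes
# nothing beyond numpy, and the helper name is an editorial placeholder).
def _example_guess_numpy_type():
    assert guess_numpy_type_from_string('float') == numpy.float32
    assert guess_numpy_type_from_string('int64') == numpy.int64
    assert guess_numpy_type_from_dtype(numpy.dtype('float64')) == numpy.float64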

205def _elem_type_as_str(elem_type): 

206 if elem_type == TensorProto.FLOAT: # pylint: disable=E1101 

207 return 'float' 

208 if elem_type == TensorProto.BOOL: # pylint: disable=E1101 

209 return 'bool' 

210 if elem_type == TensorProto.DOUBLE: # pylint: disable=E1101 

211 return 'double' 

212 if elem_type == TensorProto.STRING: # pylint: disable=E1101 

213 return 'str' 

214 if elem_type == TensorProto.INT64: # pylint: disable=E1101 

215 return 'int64' 

216 if elem_type == TensorProto.INT32: # pylint: disable=E1101 

217 return 'int32' 

218 if elem_type == TensorProto.UINT32: # pylint: disable=E1101 

219 return 'uint32' 

220 if elem_type == TensorProto.UINT64: # pylint: disable=E1101 

221 return 'uint64' 

222 if elem_type == TensorProto.INT16: # pylint: disable=E1101 

223 return 'int16' 

224 if elem_type == TensorProto.UINT16: # pylint: disable=E1101 

225 return 'uint16' 

226 if elem_type == TensorProto.UINT8: # pylint: disable=E1101 

227 return 'uint8' 

228 if elem_type == TensorProto.INT8: # pylint: disable=E1101 

229 return 'int8' 

230 if elem_type == TensorProto.FLOAT16: # pylint: disable=E1101 

231 return 'float16' 

232 if elem_type == TensorProto.COMPLEX64: # pylint: disable=E1101 

233 return 'complex64' 

234 if elem_type == TensorProto.COMPLEX128: # pylint: disable=E1101 

235 return 'complex128' 

236 if elem_type == 0: # pylint: disable=E1101 

237 return 'unk' 

238 

239 # The following code should be refactored. 

240 selem = str(elem_type) 

241 

242 if selem.startswith("tensor_type"): 

243 this = elem_type.tensor_type 

244 et = _elem_type_as_str(this.elem_type) 

245 shape = this.shape 

246 dim = shape.dim 

247 dims = [d.dim_value for d in dim] 

248 if len(dims) == 0: 

249 dims = '?' 

250 return {'kind': 'tensor', 'elem': et, 'shape': shape} 

251 

252 if selem.startswith("optional_type"): 

253 this = elem_type.optional_type 

254 et = _elem_type_as_str(this.elem_type) 

255 shape = this.shape 

256 dim = shape.dim 

257 dims = [d.dim_value for d in dim] 

258 if len(dims) == 0: 

259 dims = '?' 

260 return {'kind': 'tensor', 'elem': et, 'shape': shape, 

261 'optional_type': True} 

262 

263 if selem.startswith("map_type"): 

264 this = elem_type.map_type 

265 kt = _elem_type_as_str(this.key_type) 

266 vt = _elem_type_as_str(this.value_type) 

267 return {'kind': 'map', 'key': kt, 'value': vt} 

268 

269 raise NotImplementedError( # pragma: no cover 

270 "elem_type '{}' is unknown\nfields:\n{}\n-----\n{}.".format( 

271 elem_type, pprint.pformat(dir(elem_type)), type(elem_type))) 

272 

273 
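# A small sketch of the scalar branch of _elem_type_as_str (illustrative only;
# the `_example_*` helper is not part of the original module).
def _example_elem_type_as_str():
    assert _elem_type_as_str(TensorProto.FLOAT) == 'float'
    assert _elem_type_as_str(TensorProto.INT64) == 'int64'
    assert _elem_type_as_str(0) == 'unk'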

274def _to_array(var): 

275 try: 

276 data = to_array(var) 

277 except ValueError as e: # pragma: no cover 

278 dims = [d for d in var.dims] 

279 if var.data_type == 1 and var.float_data is not None: 

280 try: 

281 data = _numpy_array(var.float_data, dtype=numpy.float32, 

282 copy=False).reshape(dims) 

283 except ValueError: 

284 data = _numpy_array(to_array(var)) 

285 elif var.data_type == 2 and var.uint8_data is not None: 

286 data = _numpy_array(var.uint8_data, dtype=numpy.uint8, 

287 copy=False).reshape(dims) 

288 elif var.data_type == 3 and var.int8_data is not None: 

289 data = _numpy_array(var.int8_data, dtype=numpy.int8, 

290 copy=False).reshape(dims) 

291 elif var.data_type == 4 and var.uint16_data is not None: 

292 data = _numpy_array(var.uint16_data, dtype=numpy.uint16, 

293 copy=False).reshape(dims) 

294 elif var.data_type == 5 and var.int16_data is not None: 

295 data = _numpy_array(var.int16_data, dtype=numpy.int16, 

296 copy=False).reshape(dims) 

297 elif var.data_type == 6 and var.int32_data is not None: 

298 data = _numpy_array(var.int32_data, dtype=numpy.int32, 

299 copy=False).reshape(dims) 

300 elif var.data_type == 7 and var.int64_data is not None: 

301 data = _numpy_array(var.int64_data, dtype=numpy.int64, 

302 copy=False).reshape(dims) 

303 elif var.data_type == 11 and var.double_data is not None: 

304 try: 

305 data = _numpy_array(var.double_data, dtype=numpy.float64, 

306 copy=False).reshape(dims) 

307 except ValueError: 

308 data = _numpy_array(to_array(var)) 

309 elif var.data_type == 16 and var.float16_data is not None: 

310 data = _numpy_array(var.float16_data, dtype=numpy.float16, 

311 copy=False).reshape(dims) 

312 else: 

313 raise NotImplementedError( 

314 "Iniatilizer {} cannot be converted into a dictionary.".format(var)) from e 

315 return data 

316 

317 

318def _var_as_dict(var): 

319 """ 

320 Converts a protobuf object into something readable. 

321 The current implementation relies on :epkg:`json`. 

322 That's not the most efficient way. 

323 """ 

324 if hasattr(var, 'type') and str(var.type) != '': 

325 # variable 

326 if var.type is not None: 

327 if hasattr(var, 'sparse_tensor') and var.type == 11: 

328 # sparse tensor 

329 t = var.sparse_tensor 

330 values = _var_as_dict(t.values) 

331 dims = list(t.dims) 

332 dtype = dict(kind='sparse_tensor', shape=tuple(dims), elem=1) 

333 elif (hasattr(var.type, 'tensor_type') and 

334 var.type.tensor_type.elem_type > 0): 

335 t = var.type.tensor_type 

336 elem_type = _elem_type_as_str(t.elem_type) 

337 shape = t.shape 

338 dim = shape.dim 

339 dims = [d.dim_value for d in dim] 

340 if len(dims) == 0: 

341 dims = '?' 

342 dtype = dict(kind='tensor', elem=elem_type, 

343 shape=tuple(dims)) 

344 elif (hasattr(var.type, 'optional_type') and 

345 var.type.tensor_type.elem_type > 0): 

346 t = var.type.optional_type 

347 elem_type = _elem_type_as_str(t.elem_type) 

348 shape = t.shape 

349 dim = shape.dim 

350 dims = [d.dim_value for d in dim] 

351 if len(dims) == 0: 

352 dims = '?' 

353 dtype = dict(kind='tensor', elem=elem_type, 

354 shape=tuple(dims), optional_type=True) 

355 elif (hasattr(var.type, 'real') and var.type.real == 5 and 

356 hasattr(var, 'g')): 

357 dtype = dict(kind='graph', elem=var.type.real) 

358 elif (hasattr(var.type, 'real') and var.type.real == 4 and 

359 hasattr(var, 't')): 

360 dtype = dict(kind='tensor', elem=var.type.real) 

361 elif hasattr(var.type, 'real'): 

362 dtype = dict(kind='real', elem=var.type.real) 

363 elif (hasattr(var.type, "sequence_type") and 

364 var.type.sequence_type is not None and 

365 str(var.type.sequence_type.elem_type) != ''): 

366 t = var.type.sequence_type 

367 elem_type = _elem_type_as_str(t.elem_type) 

368 dtype = dict(kind='sequence', elem=elem_type) 

369 elif (hasattr(var.type, "map_type") and 

370 var.type.map_type is not None and 

371 str(var.type.map_type.key_type) != '' and 

372 str(var.type.map_type.value_type) != ''): 

373 t = var.type.map_type 

374 key_type = _elem_type_as_str(t.key_type) 

375 value_type = _elem_type_as_str(t.value_type) 

376 dtype = dict(kind='map', key=key_type, value=value_type) 

377 elif (hasattr(var.type, 'tensor_type') and 

378 var.type.tensor_type.elem_type == 0): 

379 if hasattr(var.type, 'optional_type'): 

380 optional = var.type.optional_type 

381 else: 

382 optional = None 

383 t = var.type.tensor_type 

384 elem_type = _elem_type_as_str(t.elem_type) 

385 shape = t.shape 

386 dim = shape.dim 

387 dims = [d.dim_value for d in dim] 

388 if len(dims) == 0: 

389 dims = '?' 

390 dtype = dict(kind='tensor', elem=elem_type, 

391 shape=tuple(dims)) 

392 if optional is not None: 

393 dtype['optional'] = _var_as_dict(optional) 

394 else: 

395 raise NotImplementedError( # pragma: no cover 

396 "Unable to convert a type into a dictionary for '{}'. " 

397 "Available fields: {}.".format( 

398 var.type, pprint.pformat(dir(var.type)))) 

399 else: 

400 raise NotImplementedError( # pragma: no cover 

401 "Unable to convert variable into a dictionary for '{}'. " 

402 "Available fields: {}.".format( 

403 var, pprint.pformat(dir(var.type)))) 

404 

405 res = dict(name=var.name, type=dtype) 

406 

407 if (hasattr(var, 'sparse_tensor') and dtype.get('elem', None) == 1 and 

408 dtype['kind'] == 'sparse_tensor'): 

409 # sparse matrix 

410 t = var.sparse_tensor 

411 try: 

412 values = _var_as_dict(t.values) 

413 except NotImplementedError as e: # pragma: no cover 

414 raise NotImplementedError( 

415 "Issue with\n{}\n---".format(var)) from e 

416 indices = _var_as_dict(t.indices) 

417 res['value'] = _sparse_array( 

418 dtype['shape'], values['value'], indices['value'], dtype=numpy.float32) 

419 elif hasattr(var, 'floats') and dtype.get('elem', None) == 6: 

420 res['value'] = _numpy_array(var.floats, dtype=numpy.float32) 

421 elif hasattr(var, 'strings') and dtype.get('elem', None) == 8: 

422 res['value'] = _numpy_array(var.strings) 

423 elif hasattr(var, 'ints') and dtype.get('elem', None) == 7: 

424 res['value'] = _numpy_array(var.ints) 

425 elif hasattr(var, 'f') and dtype.get('elem', None) == 1: 

426 res['value'] = var.f 

427 elif hasattr(var, 's') and dtype.get('elem', None) == 3: 

428 res['value'] = var.s 

429 elif hasattr(var, 'i') and dtype.get('elem', None) == 2: 

430 res['value'] = var.i 

431 elif hasattr(var, 'g') and dtype.get('elem', None) == 5: 

432 res['value'] = var.g 

433 elif hasattr(var, 't') and dtype.get('elem', None) == 4: 

434 ts = _var_as_dict(var.t) 

435 res['value'] = ts['value'] 

436 elif hasattr(var, 'sparse_tensor') and dtype.get('elem', None) == 11: 

437 ts = _var_as_dict(var.sparse_tensor) 

438 res['value'] = ts['value'] 

439 elif "'value'" in str(var): 

440 warnings.warn("No value: {} -- {}".format( # pragma: no cover 

441 dtype, str(var).replace("\n", "").replace(" ", ""))) 

442 return res 

443 

444 if hasattr(var, 'op_type'): 

445 if hasattr(var, 'attribute'): 

446 atts = {} 

447 for att in var.attribute: 

448 atts[att.name] = _var_as_dict(att) 

449 return dict(name=var.name, op_type=var.op_type, 

450 domain=var.domain, atts=atts) 

451 if hasattr(var, 'dims') and len(var.dims) > 0: 

452 # initializer 

453 data = _to_array(var) 

454 return dict(name=var.name, value=data) 

455 if hasattr(var, 'data_type') and var.data_type > 0: 

456 data = _to_array(var) 

457 return dict(name=var.name, value=data) 

458 if isinstance(var, str): 

459 return dict(name=var) 

460 if str(var) == '': 

461 return None 

462 raise NotImplementedError( # pragma: no cover 

463 "Unable to guess which object it is type is %r value is %r." 

464 "" % (type(var), str(var))) 

465 

466 
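# An illustrative sketch of _var_as_dict applied to a NodeProto (editorial example;
# onnx.helper.make_node is a standard onnx helper and is assumed to be available).
def _example_var_as_dict_node():
    from onnx.helper import make_node
    node = make_node('Transpose', ['X'], ['Y'], perm=[1, 0])
    d = _var_as_dict(node)
    # d contains op_type='Transpose', domain='' and
    # atts={'perm': {'name': 'perm', 'type': ..., 'value': array([1, 0])}}
    return d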

467def get_dtype_shape(obj): 

468 """ 

469 Returns the dtype and shape of a tensor. 

470 

471 :param obj: onnx object 

472 :return: `(dtype, shape)`, `(dtype, None)` if the shape is unknown, or `None` if not a tensor 

473 """ 

474 if not hasattr(obj, 'type'): 

475 return None 

476 t = obj.type 

477 if not hasattr(t, 'tensor_type'): 

478 return None 

479 t = t.tensor_type 

480 dtype = t.elem_type 

481 if not hasattr(t, 'shape'): 

482 return dtype, None 

483 shape = t.shape 

484 ds = [] 

485 for dim in shape.dim: 

486 d = dim.dim_value 

487 s = dim.dim_param 

488 if d == 0: 

489 if s == '': 

490 ds.append(None) 

491 else: 

492 ds.append(s) 

493 else: 

494 ds.append(d) 

495 return dtype, tuple(ds) 

496 

497 
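# A quick sketch of get_dtype_shape on a ValueInfoProto with a symbolic dimension
# (illustrative only; make_tensor_value_info comes from onnx.helper).
def _example_get_dtype_shape():
    from onnx.helper import make_tensor_value_info
    value_info = make_tensor_value_info('X', TensorProto.FLOAT, ['N', 4])
    # elem_type 1 is TensorProto.FLOAT, 'N' is kept as a named dimension
    return get_dtype_shape(value_info)  # (1, ('N', 4))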

498def onnx_model_opsets(onnx_model): 

499 """ 

500 Extracts opsets into a dictionary. 

501 

502 :param onnx_model: ONNX graph 

503 :return: dictionary `{domain: version}` 

504 """ 

505 res = {} 

506 for oimp in onnx_model.opset_import: 

507 res[oimp.domain] = oimp.version 

508 return res 

509 

510 
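# A compact sketch building a one-node model and reading its opsets (illustrative
# only; all helpers come from onnx.helper and opset 15 is an arbitrary choice here).
def _example_onnx_model_opsets():
    from onnx.helper import (
        make_node, make_graph, make_model, make_opsetid, make_tensor_value_info)
    X = make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
    Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, 2])
    graph = make_graph([make_node('Identity', ['X'], ['Y'])], 'g', [X], [Y])
    model = make_model(graph, opset_imports=[make_opsetid('', 15)])
    return onnx_model_opsets(model)  # {'': 15}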

511def _type_to_string(dtype): 

512 """ 

513 Converts a type into a readable string. 

514 """ 

515 if not isinstance(dtype, dict): 

516 dtype_ = _var_as_dict(dtype) # pragma: no cover 

517 else: 

518 dtype_ = dtype 

519 if dtype_["kind"] == 'tensor': 

520 return "{0}({1})".format(dtype_['elem'], dtype_['shape']) 

521 if dtype_['kind'] == 'sequence': 

522 return "[{0}]".format(_type_to_string(dtype_['elem'])) 

523 if dtype_["kind"] == 'map': 

524 return "{{{0}, {1}}}".format(dtype_['key'], dtype_['value']) 

525 raise NotImplementedError( # pragma: no cover 

526 "Unable to convert into string {} or {}.".format(dtype, dtype_)) 

527 

528 

529def numpy_min(x): 

530 """ 

531 Returns the minimum of an array. 

532 Deals with text as well. 

533 """ 

534 try: 

535 if hasattr(x, 'todense'): 

536 x = x.todense() 

537 if x.dtype.kind not in 'cUC': 

538 return x.min() 

539 try: # pragma: no cover 

540 x = x.ravel() 

541 except AttributeError: # pragma: no cover 

542 pass 

543 keep = list(filter(lambda s: isinstance(s, str), x)) 

544 if len(keep) == 0: # pragma: no cover 

545 return numpy.nan 

546 keep.sort() 

547 val = keep[0] 

548 if len(val) > 10: # pragma: no cover 

549 val = val[:10] + '...' 

550 return "%r" % val 

551 except (ValueError, TypeError): # pragma: no cover 

552 return '?' 

553 

554 

555def numpy_max(x): 

556 """ 

557 Returns the maximum of an array. 

558 Deals with text as well. 

559 """ 

560 try: 

561 if hasattr(x, 'todense'): 

562 x = x.todense() 

563 if x.dtype.kind not in 'cUC': 

564 return x.max() 

565 try: # pragma: no cover 

566 x = x.ravel() 

567 except AttributeError: # pragma: no cover 

568 pass 

569 keep = list(filter(lambda s: isinstance(s, str), x)) 

570 if len(keep) == 0: # pragma: no cover 

571 return numpy.nan 

572 keep.sort() 

573 val = keep[-1] 

574 if len(val) > 10: # pragma: no cover 

575 val = val[:10] + '...' 

576 return "%r" % val 

577 except (ValueError, TypeError): # pragma: no cover 

578 return '?' 

579 

580 

581def guess_proto_dtype(dtype): 

582 """ 

583 Guesses the ONNX dtype given a numpy dtype. 

584 

585 :param dtype: numpy dtype 

586 :return: proto type 

587 """ 

588 if dtype == numpy.float32: 

589 return TensorProto.FLOAT # pylint: disable=E1101 

590 if dtype == numpy.float64: 

591 return TensorProto.DOUBLE # pylint: disable=E1101 

592 if dtype == numpy.int64: 

593 return TensorProto.INT64 # pylint: disable=E1101 

594 if dtype == numpy.int32: 

595 return TensorProto.INT32 # pylint: disable=E1101 

596 if dtype == numpy.int16: 

597 return TensorProto.INT16 # pylint: disable=E1101 

598 if dtype == numpy.int8: 

599 return TensorProto.INT8 # pylint: disable=E1101 

600 if dtype == numpy.uint64: 

601 return TensorProto.UINT64 # pylint: disable=E1101 

602 if dtype == numpy.uint32: 

603 return TensorProto.UINT32 # pylint: disable=E1101 

604 if dtype == numpy.uint16: 

605 return TensorProto.UINT16 # pylint: disable=E1101 

606 if dtype == numpy.uint8: 

607 return TensorProto.UINT8 # pylint: disable=E1101 

608 if dtype == numpy.float16: 

609 return TensorProto.FLOAT16 # pylint: disable=E1101 

610 if dtype in (bool, numpy.bool_): 

611 return TensorProto.BOOL # pylint: disable=E1101 

612 if dtype in (str, numpy.str_): 

613 return TensorProto.STRING # pylint: disable=E1101 

614 raise RuntimeError( 

615 "Unable to guess type for dtype={}.".format(dtype)) # pragma: no cover 

616 

617 
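# A minimal sketch of guess_proto_dtype (illustrative only; guess_dtype, defined
# further below, performs the reverse mapping).
def _example_guess_proto_dtype():
    assert guess_proto_dtype(numpy.float32) == TensorProto.FLOAT
    assert guess_proto_dtype(numpy.int64) == TensorProto.INT64
    assert guess_proto_dtype(numpy.bool_) == TensorProto.BOOL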

618def guess_proto_dtype_name(onnx_dtype): 

619 """ 

620 Returns a string equivalent to `onnx_dtype`. 

621 

622 :param onnx_dtype: ONNX dtype

623 :return: string such as `'TensorProto.FLOAT'`

624 """ 

625 if onnx_dtype == TensorProto.FLOAT: # pylint: disable=E1101 

626 return "TensorProto.FLOAT" 

627 if onnx_dtype == TensorProto.DOUBLE: # pylint: disable=E1101 

628 return "TensorProto.DOUBLE" 

629 if onnx_dtype == TensorProto.INT64: # pylint: disable=E1101 

630 return "TensorProto.INT64" 

631 if onnx_dtype == TensorProto.INT32: # pylint: disable=E1101 

632 return "TensorProto.INT32" 

633 if onnx_dtype == TensorProto.INT16: # pylint: disable=E1101 

634 return "TensorProto.INT16" 

635 if onnx_dtype == TensorProto.UINT8: # pylint: disable=E1101 

636 return "TensorProto.UINT8" 

637 if onnx_dtype == TensorProto.FLOAT16: # pylint: disable=E1101 

638 return "TensorProto.FLOAT16" 

639 if onnx_dtype == TensorProto.BOOL: # pylint: disable=E1101 

640 return "TensorProto.BOOL" 

641 if onnx_dtype == TensorProto.STRING: # pylint: disable=E1101 

642 return "TensorProto.STRING" 

643 raise RuntimeError( # pragma: no cover 

644 "Unable to guess type for dtype={}.".format(onnx_dtype)) 

645 

646 

647def guess_dtype(proto_type): 

648 """ 

649 Converts a proto type into a :epkg:`numpy` type. 

650 

651 :param proto_type: example ``onnx.TensorProto.FLOAT`` 

652 :return: :epkg:`numpy` dtype 

653 """ 

654 if proto_type == TensorProto.FLOAT: # pylint: disable=E1101 

655 return numpy.float32 

656 if proto_type == TensorProto.BOOL: # pylint: disable=E1101 

657 return numpy.bool_ 

658 if proto_type == TensorProto.DOUBLE: # pylint: disable=E1101 

659 return numpy.float64 

660 if proto_type == TensorProto.STRING: # pylint: disable=E1101 

661 return numpy.str_ 

662 if proto_type == TensorProto.INT64: # pylint: disable=E1101 

663 return numpy.int64 

664 if proto_type == TensorProto.INT32: # pylint: disable=E1101 

665 return numpy.int32 

666 if proto_type == TensorProto.INT8: # pylint: disable=E1101 

667 return numpy.int8 

668 if proto_type == TensorProto.INT16: # pylint: disable=E1101 

669 return numpy.int16 

670 if proto_type == TensorProto.UINT64: # pylint: disable=E1101 

671 return numpy.uint64 

672 if proto_type == TensorProto.UINT32: # pylint: disable=E1101 

673 return numpy.uint32 

674 if proto_type == TensorProto.UINT8: # pylint: disable=E1101 

675 return numpy.uint8 

676 if proto_type == TensorProto.UINT16: # pylint: disable=E1101 

677 return numpy.uint16 

678 if proto_type == TensorProto.FLOAT16: # pylint: disable=E1101 

679 return numpy.float16 

680 raise ValueError( 

681 "Unable to convert proto_type {} to numpy type.".format( 

682 proto_type)) 

683 

684 

685def to_skl2onnx_type(name, elem_type, shape): 

686 """ 

687 Converts *name*, *elem_type*, *shape* into a 

688 :epkg:`sklearn-onnx` type. 

689 

690 :param name: string 

691 :param elem_type: tensor of elements of this type 

692 :param shape: expected shape 

693 :return: data type 

694 """ 

695 from skl2onnx.common.data_types import _guess_numpy_type # delayed 

696 elem = guess_numpy_type_from_string(elem_type) 

697 shape = list(None if d == 0 else d for d in shape) 

698 return (name, _guess_numpy_type(elem, shape)) 

699 

700 

701def from_pb(obj): 

702 """ 

703 Extracts tensor description from a protobuf. 

704 

705 :param obj: initializer, tensor 

706 :return: (name, type, shape) 

707 """ 

708 def get_dim(d): 

709 r = d.dim_value 

710 if "dim_param" in str(d): 

711 return None 

712 if r == 0: 

713 # dim_value is 0 when it is 0 or undefined 

714 return 0 if "0" in str(d) else None 

715 return r 

716 

717 def get_shape(tt): 

718 return [get_dim(tt.shape.dim[i]) 

719 for i in range(len(tt.shape.dim))] 

720 

721 if hasattr(obj, 'extend'): 

722 return [from_pb(o) for o in obj] 

723 

724 name = obj.name 

725 if obj.type.tensor_type: 

726 tt = obj.type.tensor_type 

727 elem = tt.elem_type 

728 shape = get_shape(tt) 

729 if elem not in TENSOR_TYPE_TO_NP_TYPE: 

730 raise NotImplementedError( 

731 "Unsupported type '{}' (elem_type={}).".format( 

732 type(obj.type.tensor_type), elem)) 

733 ty = TENSOR_TYPE_TO_NP_TYPE[elem].type 

734 else: 

735 raise NotImplementedError("Unsupported type '{}' as " 

736 "a string ({}).".format( 

737 type(obj), obj)) 

738 

739 return (name, ty, shape) 

740 

741 
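# A short sketch of from_pb extracting (name, numpy type, shape) from a
# ValueInfoProto (illustrative only; make_tensor_value_info comes from onnx.helper).
def _example_from_pb():
    from onnx.helper import make_tensor_value_info
    value_info = make_tensor_value_info('X', TensorProto.DOUBLE, [3, 4])
    return from_pb(value_info)  # ('X', numpy.float64, [3, 4])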

742def numpy_type_prototype(dtype): 

743 """ 

744 Converts a numpy dtype into a TensorProto dtype. 

745 

746 :param dtype: dtype 

747 :return: proto dtype 

748 """ 

749 if dtype in NP_TYPE_TO_TENSOR_TYPE: 

750 return NP_TYPE_TO_TENSOR_TYPE[dtype] 

751 dt = numpy.dtype(dtype) 

752 if dt in NP_TYPE_TO_TENSOR_TYPE: 

753 return NP_TYPE_TO_TENSOR_TYPE[dt] 

754 raise ValueError( # pragma: no cover 

755 "Unable to convert dtype %r into ProtoType." % dtype) 

756 

757 

758def make_value_info(name, dtype, shape): 

759 """ 

760 Converts a variable defined by its name, type and shape 

761 into `onnx.ValueInfoProto`. 

762 

763 :return: instance of `onnx.ValueInfoProto` 

764 """ 

765 value_info = ValueInfoProto() 

766 value_info.name = name 

767 tensor_type_proto = make_tensor_type_proto( 

768 numpy_type_prototype(dtype), shape) 

769 value_info.type.CopyFrom(tensor_type_proto) # pylint: disable=E1101 

770 return value_info 

771 

772 
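# A small sketch of make_value_info from a numpy dtype and a shape with an
# unknown dimension (illustrative example, not part of the original module).
def _example_make_value_info():
    value_info = make_value_info('X', numpy.float32, [None, 4])
    # value_info.type.tensor_type.elem_type == TensorProto.FLOAT
    return get_dtype_shape(value_info)  # (1, (None, 4))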

773_get_onnx_function_cache = None 

774 

775 

776def _get_onnx_function(): 

777 """ 

778 Returns the functions defined in the ONNX package, indexed by (domain, name). 

779 """ 

780 global _get_onnx_function_cache # pylint: disable=W0603 

781 if _get_onnx_function_cache is None: 

782 _get_onnx_function_cache = {} 

783 fcts = get_function_ops() 

784 for fct in fcts: 

785 key = fct.domain, fct.name 

786 if key in _get_onnx_function_cache: 

787 raise RuntimeError( # pragma: no cover 

788 "Function %r is already registered." % (key, )) 

789 _get_onnx_function_cache[key] = fct 

790 return _get_onnx_function_cache 

791 

792 

793def get_onnx_schema(opname, domain='', opset=None, load_function=False): 

794 """ 

795 Returns the operator schema for a specific operator. 

796 

797 :param opname: operator name

798 :param domain: operator domain

799 :param opset: opset or version, None for the latest

800 :param load_function: if True, also looks into the list of ONNX

801 functions and returns the one with the same name if it exists;

802 opset must be None in that case

803 :return: :epkg:`OpSchema` 

804 """ 

805 if load_function: 

806 if opset is not None: 

807 raise ValueError( 

808 "opset must be None if load_function is True for " 

809 "operator (%r,%r)." % (domain, opname)) 

810 fcts = _get_onnx_function() 

811 key = domain, opname 

812 if key in fcts: 

813 return fcts[key] 

814 if opset is None: 

815 opset = onnx_opset_version() 

816 return get_schema(opname, opset, domain) 

817 if opset is None: 

818 opset = onnx_opset_version() 

819 return get_schema(opname, opset, domain)
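# A final sketch of get_onnx_schema (illustrative only; 'Add' is a standard ONNX
# operator, so a schema is expected to exist for the installed opset).
def _example_get_onnx_schema():
    schema = get_onnx_schema('Add')
    # schema.name == 'Add'; schema.since_version gives the opset that introduced it
    return schema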