Coverage for mlprodict/onnx_tools/onnx2py_helper.py: 94%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Functions which converts :epkg:`ONNX` object into
4readable :epkg:`python` objects.
5"""
6import pprint
7import warnings
8import numpy
9from scipy.sparse import coo_matrix
10from onnx.defs import get_schema, get_function_ops, onnx_opset_version
11from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE, TENSOR_TYPE_TO_NP_TYPE
12from onnx import TensorProto, ValueInfoProto
13from onnx.helper import make_tensor_type_proto
14from onnx.numpy_helper import to_array, from_array as onnx_from_array
def to_bytes(val):
    """
    Converts an array into protobuf and then into bytes.
    Useful to serialize an array.

    :param val: numpy array (anything else is assumed to be a
        protobuf object already)
    :return: bytes

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from mlprodict.onnx_tools.onnx2py_helper import to_bytes

        data = numpy.array([[0, 1], [2, 3], [4, 5]], dtype=numpy.float32)
        pb = to_bytes(data)
        print(len(pb), data.size * data.itemsize, pb[:10])
    """
    # Non-array inputs are expected to already expose SerializeToString.
    proto = from_array(val) if isinstance(val, numpy.ndarray) else val
    return proto.SerializeToString()
def from_array(value, name=None):
    """
    Converts an array into an ONNX tensor.

    :param value: numpy array (a :epkg:`TensorProto` is returned unchanged)
    :param name: name to give to the tensor
    :return: ONNX tensor
    :raises NotImplementedError: if the value cannot be converted
    """
    if isinstance(value, TensorProto):  # pragma: no cover
        return value
    if not isinstance(value, numpy.ndarray):
        raise NotImplementedError(  # pragma: no cover
            "Unable to convert type %r into an ONNX tensor." % type(value))
    try:
        return onnx_from_array(value, name=name)
    except NotImplementedError as e:  # pragma: no cover
        if value.dtype != numpy.dtype('O'):
            raise NotImplementedError(
                "Unable to convert type %r (dtype=%r) into an ONNX tensor "
                "due to %r." % (type(value), value.dtype, e)) from e
        # Object arrays are serialized as string tensors.
        tensor = TensorProto()
        tensor.data_type = TensorProto.STRING  # pylint: disable=E1101
        if name is not None:
            tensor.name = name
        tensor.dims.extend(value.shape)  # pylint: disable=E1101
        tensor.string_data.extend(  # pylint: disable=E1101
            [str(o).encode('utf-8') for o in value.ravel()])
        return tensor
def from_bytes(b):
    """
    Retrieves an array from bytes then protobuf.
    Useful to deserialize what :func:`to_bytes` produced.

    :param b: bytes holding a serialized :epkg:`TensorProto`
        (anything else is assumed to be a protobuf object already)
    :return: numpy array

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from mlprodict.onnx_tools.onnx2py_helper import to_bytes, from_bytes

        data = numpy.array([[0, 1], [2, 3], [4, 5]], dtype=numpy.float32)
        pb = to_bytes(data)
        data2 = from_bytes(pb)
        print(data2)
    """
    if not isinstance(b, bytes):  # pragma: no cover
        return to_array(b)
    tensor = TensorProto()
    tensor.ParseFromString(b)
    return to_array(tensor)
109def _numpy_array(data, dtype=None, copy=True):
110 """
111 Single function to create an array.
113 @param data data
114 @param dtype dtype
115 @param copy copy
116 @return numpy array
117 """
118 if isinstance(data, numpy.ndarray):
119 res = data
120 else:
121 res = numpy.array(data, dtype=dtype, copy=copy)
122 return res
125def _sparse_array(shape, data, indices, dtype=None, copy=True):
126 """
127 Single function to create an sparse array
128 (:epkg:`coo_matrix`).
130 @param shape shape
131 @param data data
132 @param indices indices
133 @param dtype dtype
134 @param copy copy
135 @return :epkg:`coo_matrix`
136 """
137 if len(shape) != 2:
138 raise ValueError( # pragma: no cover
139 "Only matrices are allowed or sparse matrices "
140 "but shape is {}.".format(shape))
141 rows = numpy.array([i // shape[1] for i in indices])
142 cols = numpy.array([i % shape[1] for i in indices])
143 if isinstance(data, numpy.ndarray):
144 res = coo_matrix((data, (rows, cols)), dtype=dtype)
145 else:
146 res = coo_matrix( # pragma: no cover
147 (numpy.array(data, dtype=dtype, copy=copy),
148 (rows, cols)), dtype=dtype)
149 return res
def guess_numpy_type_from_string(name):
    """
    Converts a string (such as `'float'`) into a
    numpy dtype.
    """
    mapping = {
        'float': numpy.float32,
        'float32': numpy.float32,
        'double': numpy.float64,
        'float64': numpy.float64,
        'float16': numpy.float16,
        'int64': numpy.int64,
        'int32': numpy.int32,
        'int16': numpy.int16,
        'int8': numpy.int8,
        'uint8': numpy.uint8,
        'bool': numpy.bool_,
        'str': numpy.str_,
    }
    if name in mapping:
        return mapping[name]
    raise ValueError(  # pragma: no cover
        "Unable to guess numpy dtype from %r." % name)
def guess_numpy_type_from_dtype(dt):
    """
    Converts a dtype (such as `numpy.dtype('float32')`) into a
    numpy type.

    :param dt: numpy type or dtype instance
    :return: numpy type (such as `numpy.float32`)
    :raises ValueError: if the dtype cannot be interpreted
    """
    if dt in {numpy.int8, numpy.uint8, numpy.float16, numpy.float32,
              numpy.float64, numpy.int32, numpy.int64, numpy.int16,
              numpy.uint16, numpy.uint32, numpy.bool_, numpy.str_,
              numpy.uint64, bool, str, }:
        return dt
    # dtype instances do not hash like the corresponding numpy types,
    # so they fail the set membership above and are compared explicitly.
    # The original code only compared five of them; all numeric types
    # claimed by the set above are now supported.
    for ty in (numpy.float32, numpy.float64, numpy.float16,
               numpy.int8, numpy.int16, numpy.int32, numpy.int64,
               numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64,
               numpy.bool_):
        if dt == numpy.dtype(ty):
            return ty
    raise ValueError(  # pragma: no cover
        "Unable to guess numpy dtype from %r." % dt)
def _elem_type_as_str(elem_type):
    """
    Converts an ONNX element type into a readable string
    or a dictionary for nested types.

    :param elem_type: integer constant from :epkg:`TensorProto`
        (such as ``TensorProto.FLOAT``) or a nested type protobuf
        (tensor, optional or map type)
    :return: a short string such as ``'float'`` or ``'int64'``,
        ``'unk'`` for the undefined type 0, or a dictionary
        describing a nested type
    :raises NotImplementedError: if the type cannot be interpreted
    """
    if elem_type == TensorProto.FLOAT:  # pylint: disable=E1101
        return 'float'
    if elem_type == TensorProto.BOOL:  # pylint: disable=E1101
        return 'bool'
    if elem_type == TensorProto.DOUBLE:  # pylint: disable=E1101
        return 'double'
    if elem_type == TensorProto.STRING:  # pylint: disable=E1101
        return 'str'
    if elem_type == TensorProto.INT64:  # pylint: disable=E1101
        return 'int64'
    if elem_type == TensorProto.INT32:  # pylint: disable=E1101
        return 'int32'
    if elem_type == TensorProto.UINT32:  # pylint: disable=E1101
        return 'uint32'
    if elem_type == TensorProto.UINT64:  # pylint: disable=E1101
        return 'uint64'
    if elem_type == TensorProto.INT16:  # pylint: disable=E1101
        return 'int16'
    if elem_type == TensorProto.UINT16:  # pylint: disable=E1101
        return 'uint16'
    if elem_type == TensorProto.UINT8:  # pylint: disable=E1101
        return 'uint8'
    if elem_type == TensorProto.INT8:  # pylint: disable=E1101
        return 'int8'
    if elem_type == TensorProto.FLOAT16:  # pylint: disable=E1101
        return 'float16'
    if elem_type == TensorProto.COMPLEX64:  # pylint: disable=E1101
        return 'complex64'
    if elem_type == TensorProto.COMPLEX128:  # pylint: disable=E1101
        return 'complex128'
    if elem_type == 0:  # pylint: disable=E1101
        return 'unk'

    # The following code should be refactored.
    # Nested types are recognized through the prefix of their
    # protobuf string representation.
    selem = str(elem_type)

    if selem.startswith("tensor_type"):
        this = elem_type.tensor_type
        et = _elem_type_as_str(this.elem_type)
        shape = this.shape
        dim = shape.dim
        dims = [d.dim_value for d in dim]
        if len(dims) == 0:
            dims = '?'
        # NOTE(review): dims is computed above but the shape protobuf is
        # returned as-is; confirm whether 'shape' should be dims.
        return {'kind': 'tensor', 'elem': et, 'shape': shape}

    if selem.startswith("optional_type"):
        this = elem_type.optional_type
        et = _elem_type_as_str(this.elem_type)
        shape = this.shape
        dim = shape.dim
        dims = [d.dim_value for d in dim]
        if len(dims) == 0:
            dims = '?'
        return {'kind': 'tensor', 'elem': et, 'shape': shape,
                'optional_type': True}

    if selem.startswith("map_type"):
        this = elem_type.map_type
        kt = _elem_type_as_str(this.key_type)
        vt = _elem_type_as_str(this.value_type)
        return {'kind': 'map', 'key': kt, 'value': vt}

    raise NotImplementedError(  # pragma: no cover
        "elem_type '{}' is unknown\nfields:\n{}\n-----\n{}.".format(
            elem_type, pprint.pformat(dir(elem_type)), type(elem_type)))
def _to_array(var):
    """
    Converts a :epkg:`TensorProto` into a numpy array.
    Relies on :func:`to_array` and falls back on the typed data
    fields of the protobuf when it fails.

    :param var: :epkg:`TensorProto`
    :return: numpy array
    :raises NotImplementedError: if the fallback does not handle
        the tensor data type
    """
    try:
        data = to_array(var)
    except ValueError as e:  # pragma: no cover
        # Fallback: rebuilds the array from the typed field matching
        # var.data_type (onnx constants: 1=FLOAT, 2=UINT8, 3=INT8,
        # 4=UINT16, 5=INT16, 6=INT32, 7=INT64, 11=DOUBLE).
        # NOTE(review): protobuf repeated fields are containers and never
        # None, so the `is not None` checks always hold; the dispatch
        # effectively relies on data_type only — confirm intent.
        dims = [d for d in var.dims]
        if var.data_type == 1 and var.float_data is not None:
            try:
                data = _numpy_array(var.float_data, dtype=numpy.float32,
                                    copy=False).reshape(dims)
            except ValueError:
                data = _numpy_array(to_array(var))
        elif var.data_type == 2 and var.uint8_data is not None:
            data = _numpy_array(var.uint8_data, dtype=numpy.uint8,
                                copy=False).reshape(dims)
        elif var.data_type == 3 and var.int8_data is not None:
            data = _numpy_array(var.int8_data, dtype=numpy.int8,
                                copy=False).reshape(dims)
        elif var.data_type == 4 and var.uint16_data is not None:
            data = _numpy_array(var.uint16_data, dtype=numpy.uint16,
                                copy=False).reshape(dims)
        elif var.data_type == 5 and var.int16_data is not None:
            data = _numpy_array(var.int16_data, dtype=numpy.int16,
                                copy=False).reshape(dims)
        elif var.data_type == 6 and var.int32_data is not None:
            data = _numpy_array(var.int32_data, dtype=numpy.int32,
                                copy=False).reshape(dims)
        elif var.data_type == 7 and var.int64_data is not None:
            data = _numpy_array(var.int64_data, dtype=numpy.int64,
                                copy=False).reshape(dims)
        elif var.data_type == 11 and var.double_data is not None:
            try:
                data = _numpy_array(var.double_data, dtype=numpy.float64,
                                    copy=False).reshape(dims)
            except ValueError:
                data = _numpy_array(to_array(var))
        elif var.data_type == 16 and var.float16_data is not None:
            # NOTE(review): in onnx, 16 is BFLOAT16 and FLOAT16 is 10,
            # and 'float16_data' is not a standard TensorProto field —
            # confirm this branch against the onnx version in use.
            data = _numpy_array(var.float16_data, dtype=numpy.float16,
                                copy=False).reshape(dims)
        else:
            raise NotImplementedError(
                "Iniatilizer {} cannot be converted into a dictionary.".format(var)) from e
    return data
def _var_as_dict(var):
    """
    Converts a protobuf object into something readable.

    :param var: protobuf object: a typed variable (value info,
        attribute, sparse tensor), a node, an initializer or a string
    :return: dictionary with at least a ``name`` key, or ``None``
        for an empty object
    :raises NotImplementedError: when the object is not recognized
    """
    # Case 1: the object carries a non-empty 'type' field
    # (value info, attribute proto, ...).
    if hasattr(var, 'type') and str(var.type) != '':
        # variable
        if var.type is not None:
            if hasattr(var, 'sparse_tensor') and var.type == 11:
                # sparse tensor
                # NOTE(review): compares var.type directly to 11
                # (AttributeProto.SPARSE_TENSOR) — confirm.
                t = var.sparse_tensor
                values = _var_as_dict(t.values)
                dims = list(t.dims)
                dtype = dict(kind='sparse_tensor', shape=tuple(dims), elem=1)
            elif (hasattr(var.type, 'tensor_type') and
                    var.type.tensor_type.elem_type > 0):
                t = var.type.tensor_type
                elem_type = _elem_type_as_str(t.elem_type)
                shape = t.shape
                dim = shape.dim
                dims = [d.dim_value for d in dim]
                if len(dims) == 0:
                    dims = '?'
                dtype = dict(kind='tensor', elem=elem_type,
                             shape=tuple(dims))
            elif (hasattr(var.type, 'optional_type') and
                    var.type.tensor_type.elem_type > 0):
                # NOTE(review): the condition reads
                # var.type.tensor_type.elem_type, not
                # var.type.optional_type — possibly a typo; confirm.
                t = var.type.optional_type
                elem_type = _elem_type_as_str(t.elem_type)
                shape = t.shape
                dim = shape.dim
                dims = [d.dim_value for d in dim]
                if len(dims) == 0:
                    dims = '?'
                dtype = dict(kind='tensor', elem=elem_type,
                             shape=tuple(dims), optional_type=True)
            elif (hasattr(var.type, 'real') and var.type.real == 5 and
                    hasattr(var, 'g')):
                # attribute holding a subgraph
                dtype = dict(kind='graph', elem=var.type.real)
            elif (hasattr(var.type, 'real') and var.type.real == 4 and
                    hasattr(var, 't')):
                # attribute holding a tensor
                dtype = dict(kind='tensor', elem=var.type.real)
            elif hasattr(var.type, 'real'):
                dtype = dict(kind='real', elem=var.type.real)
            elif (hasattr(var.type, "sequence_type") and
                    var.type.sequence_type is not None and
                    str(var.type.sequence_type.elem_type) != ''):
                t = var.type.sequence_type
                elem_type = _elem_type_as_str(t.elem_type)
                dtype = dict(kind='sequence', elem=elem_type)
            elif (hasattr(var.type, "map_type") and
                    var.type.map_type is not None and
                    str(var.type.map_type.key_type) != '' and
                    str(var.type.map_type.value_type) != ''):
                t = var.type.map_type
                key_type = _elem_type_as_str(t.key_type)
                value_type = _elem_type_as_str(t.value_type)
                dtype = dict(kind='map', key=key_type, value=value_type)
            elif (hasattr(var.type, 'tensor_type') and
                    var.type.tensor_type.elem_type == 0):
                # tensor with an undefined element type, possibly optional
                if hasattr(var.type, 'optional_type'):
                    optional = var.type.optional_type
                else:
                    optional = None
                t = var.type.tensor_type
                elem_type = _elem_type_as_str(t.elem_type)
                shape = t.shape
                dim = shape.dim
                dims = [d.dim_value for d in dim]
                if len(dims) == 0:
                    dims = '?'
                dtype = dict(kind='tensor', elem=elem_type,
                             shape=tuple(dims))
                if optional is not None:
                    dtype['optional'] = _var_as_dict(optional)
            else:
                raise NotImplementedError(  # pragma: no cover
                    "Unable to convert a type into a dictionary for '{}'. "
                    "Available fields: {}.".format(
                        var.type, pprint.pformat(dir(var.type))))
        else:
            raise NotImplementedError(  # pragma: no cover
                "Unable to convert variable into a dictionary for '{}'. "
                "Available fields: {}.".format(
                    var, pprint.pformat(dir(var.type))))

        res = dict(name=var.name, type=dtype)

        # Extracts the value carried by the object, the field depends
        # on the element type computed above (AttributeProto constants:
        # 1=FLOAT, 2=INT, 3=STRING, 4=TENSOR, 5=GRAPH, 6=FLOATS,
        # 7=INTS, 8=STRINGS, 11=SPARSE_TENSOR).
        if (hasattr(var, 'sparse_tensor') and dtype.get('elem', None) == 1 and
                dtype['kind'] == 'sparse_tensor'):
            # sparse matrix
            t = var.sparse_tensor
            try:
                values = _var_as_dict(t.values)
            except NotImplementedError as e:  # pragma: no cover
                raise NotImplementedError(
                    "Issue with\n{}\n---".format(var)) from e
            indices = _var_as_dict(t.indices)
            res['value'] = _sparse_array(
                dtype['shape'], values['value'], indices['value'], dtype=numpy.float32)
        elif hasattr(var, 'floats') and dtype.get('elem', None) == 6:
            res['value'] = _numpy_array(var.floats, dtype=numpy.float32)
        elif hasattr(var, 'strings') and dtype.get('elem', None) == 8:
            res['value'] = _numpy_array(var.strings)
        elif hasattr(var, 'ints') and dtype.get('elem', None) == 7:
            res['value'] = _numpy_array(var.ints)
        elif hasattr(var, 'f') and dtype.get('elem', None) == 1:
            res['value'] = var.f
        elif hasattr(var, 's') and dtype.get('elem', None) == 3:
            res['value'] = var.s
        elif hasattr(var, 'i') and dtype.get('elem', None) == 2:
            res['value'] = var.i
        elif hasattr(var, 'g') and dtype.get('elem', None) == 5:
            res['value'] = var.g
        elif hasattr(var, 't') and dtype.get('elem', None) == 4:
            ts = _var_as_dict(var.t)
            res['value'] = ts['value']
        elif hasattr(var, 'sparse_tensor') and dtype.get('elem', None) == 11:
            ts = _var_as_dict(var.sparse_tensor)
            res['value'] = ts['value']
        elif "'value'" in str(var):
            warnings.warn("No value: {} -- {}".format(  # pragma: no cover
                dtype, str(var).replace("\n", "").replace(" ", "")))
        return res

    # Case 2: a node.
    if hasattr(var, 'op_type'):
        if hasattr(var, 'attribute'):
            atts = {}
            for att in var.attribute:
                atts[att.name] = _var_as_dict(att)
            return dict(name=var.name, op_type=var.op_type,
                        domain=var.domain, atts=atts)
    # Case 3: an initializer (tensor with dims or a scalar tensor).
    if hasattr(var, 'dims') and len(var.dims) > 0:
        # initializer
        data = _to_array(var)
        return dict(name=var.name, value=data)
    if hasattr(var, 'data_type') and var.data_type > 0:
        data = _to_array(var)
        return dict(name=var.name, value=data)
    # Case 4: a plain name.
    if isinstance(var, str):
        return dict(name=var)
    if str(var) == '':
        return None
    raise NotImplementedError(  # pragma: no cover
        "Unable to guess which object it is type is %r value is %r."
        "" % (type(var), str(var)))
def get_dtype_shape(obj):
    """
    Returns the type and shape of a tensor.

    :param obj: onnx object with a ``type.tensor_type`` field
    :return: ``(dtype, shape)``, ``(dtype, None)`` when the shape
        is missing, or ``None`` when *obj* does not describe a tensor
        (the previous docstring wrongly advertised ``(None, None)``)
    """
    if not hasattr(obj, 'type'):
        return None
    t = obj.type
    if not hasattr(t, 'tensor_type'):
        return None
    t = t.tensor_type
    dtype = t.elem_type
    if not hasattr(t, 'shape'):
        return dtype, None
    ds = []
    for dim in t.shape.dim:
        if dim.dim_value == 0:
            # dim_value == 0 means the dimension is either symbolic
            # (named through dim_param) or unknown (None).
            ds.append(dim.dim_param if dim.dim_param != '' else None)
        else:
            ds.append(dim.dim_value)
    return dtype, tuple(ds)
def onnx_model_opsets(onnx_model):
    """
    Extracts opsets in a dictionary.

    :param onnx_model: ONNX graph
    :return: dictionary `{domain: version}`
    """
    return {oimp.domain: oimp.version
            for oimp in onnx_model.opset_import}
511def _type_to_string(dtype):
512 """
513 Converts a type into a readable string.
514 """
515 if not isinstance(dtype, dict):
516 dtype_ = _var_as_dict(dtype) # pragma: no cover
517 else:
518 dtype_ = dtype
519 if dtype_["kind"] == 'tensor':
520 return "{0}({1})".format(dtype_['elem'], dtype_['shape'])
521 if dtype_['kind'] == 'sequence':
522 return "[{0}]".format(_type_to_string(dtype_['elem']))
523 if dtype_["kind"] == 'map':
524 return "{{{0}, {1}}}".format(dtype_['key'], dtype_['value'])
525 raise NotImplementedError( # pragma: no cover
526 "Unable to convert into string {} or {}.".format(dtype, dtype_))
def numpy_min(x):
    """
    Returns the minimum of an array.
    Deals with text as well.
    """
    try:
        if hasattr(x, 'todense'):
            x = x.todense()
        if x.dtype.kind not in 'cUC':
            return x.min()
        # Text (or complex) arrays: the smallest string is returned
        # as a truncated repr.
        try:
            flat = x.ravel()
        except AttributeError:
            flat = x
        texts = sorted(s for s in flat if isinstance(s, str))
        if not texts:
            return numpy.nan
        smallest = texts[0]
        if len(smallest) > 10:
            smallest = smallest[:10] + '...'
        return "%r" % smallest
    except (ValueError, TypeError):
        return '?'
def numpy_max(x):
    """
    Returns the maximum of an array.
    Deals with text as well.
    """
    try:
        if hasattr(x, 'todense'):
            x = x.todense()
        if x.dtype.kind not in 'cUC':
            return x.max()
        # Text (or complex) arrays: the largest string is returned
        # as a truncated repr.
        try:
            flat = x.ravel()
        except AttributeError:
            flat = x
        texts = sorted(s for s in flat if isinstance(s, str))
        if not texts:
            return numpy.nan
        largest = texts[-1]
        if len(largest) > 10:
            largest = largest[:10] + '...'
        return "%r" % largest
    except (ValueError, TypeError):
        return '?'
def guess_proto_dtype(dtype):
    """
    Guesses the ONNX dtype given a numpy dtype.

    :param dtype: numpy dtype
    :return: proto type
    :raises RuntimeError: if no corresponding proto type exists
    """
    # Ordered as in the original implementation, comparison relies
    # on dtype equality.
    numeric = (
        (numpy.float32, TensorProto.FLOAT),  # pylint: disable=E1101
        (numpy.float64, TensorProto.DOUBLE),  # pylint: disable=E1101
        (numpy.int64, TensorProto.INT64),  # pylint: disable=E1101
        (numpy.int32, TensorProto.INT32),  # pylint: disable=E1101
        (numpy.int16, TensorProto.INT16),  # pylint: disable=E1101
        (numpy.int8, TensorProto.INT8),  # pylint: disable=E1101
        (numpy.uint64, TensorProto.UINT64),  # pylint: disable=E1101
        (numpy.uint32, TensorProto.UINT32),  # pylint: disable=E1101
        (numpy.uint16, TensorProto.UINT16),  # pylint: disable=E1101
        (numpy.uint8, TensorProto.UINT8),  # pylint: disable=E1101
        (numpy.float16, TensorProto.FLOAT16),  # pylint: disable=E1101
    )
    for np_dtype, proto_dtype in numeric:
        if dtype == np_dtype:
            return proto_dtype
    if dtype in (bool, numpy.bool_):
        return TensorProto.BOOL  # pylint: disable=E1101
    if dtype in (str, numpy.str_):
        return TensorProto.STRING  # pylint: disable=E1101
    raise RuntimeError(
        "Unable to guess type for dtype={}.".format(dtype))  # pragma: no cover
def guess_proto_dtype_name(onnx_dtype):
    """
    Returns a string equivalent to `onnx_dtype`.

    :param onnx_dtype: onnx dtype
    :return: proto type name as a string
    :raises RuntimeError: if the dtype is unknown
    """
    if onnx_dtype == TensorProto.FLOAT:  # pylint: disable=E1101
        return "TensorProto.FLOAT"
    if onnx_dtype == TensorProto.DOUBLE:  # pylint: disable=E1101
        return "TensorProto.DOUBLE"
    if onnx_dtype == TensorProto.INT64:  # pylint: disable=E1101
        return "TensorProto.INT64"
    if onnx_dtype == TensorProto.INT32:  # pylint: disable=E1101
        return "TensorProto.INT32"
    if onnx_dtype == TensorProto.INT16:  # pylint: disable=E1101
        return "TensorProto.INT16"
    # The following four branches were missing although
    # guess_proto_dtype and guess_dtype support these types.
    if onnx_dtype == TensorProto.INT8:  # pylint: disable=E1101
        return "TensorProto.INT8"
    if onnx_dtype == TensorProto.UINT8:  # pylint: disable=E1101
        return "TensorProto.UINT8"
    if onnx_dtype == TensorProto.UINT16:  # pylint: disable=E1101
        return "TensorProto.UINT16"
    if onnx_dtype == TensorProto.UINT32:  # pylint: disable=E1101
        return "TensorProto.UINT32"
    if onnx_dtype == TensorProto.UINT64:  # pylint: disable=E1101
        return "TensorProto.UINT64"
    if onnx_dtype == TensorProto.FLOAT16:  # pylint: disable=E1101
        return "TensorProto.FLOAT16"
    if onnx_dtype == TensorProto.BOOL:  # pylint: disable=E1101
        return "TensorProto.BOOL"
    if onnx_dtype == TensorProto.STRING:  # pylint: disable=E1101
        return "TensorProto.STRING"
    raise RuntimeError(  # pragma: no cover
        "Unable to guess type for dtype={}.".format(onnx_dtype))
def guess_dtype(proto_type):
    """
    Converts a proto type into a :epkg:`numpy` type.

    :param proto_type: example ``onnx.TensorProto.FLOAT``
    :return: :epkg:`numpy` dtype
    :raises ValueError: if the proto type has no numpy equivalent
    """
    known = (
        (TensorProto.FLOAT, numpy.float32),  # pylint: disable=E1101
        (TensorProto.BOOL, numpy.bool_),  # pylint: disable=E1101
        (TensorProto.DOUBLE, numpy.float64),  # pylint: disable=E1101
        (TensorProto.STRING, numpy.str_),  # pylint: disable=E1101
        (TensorProto.INT64, numpy.int64),  # pylint: disable=E1101
        (TensorProto.INT32, numpy.int32),  # pylint: disable=E1101
        (TensorProto.INT8, numpy.int8),  # pylint: disable=E1101
        (TensorProto.INT16, numpy.int16),  # pylint: disable=E1101
        (TensorProto.UINT64, numpy.uint64),  # pylint: disable=E1101
        (TensorProto.UINT32, numpy.uint32),  # pylint: disable=E1101
        (TensorProto.UINT8, numpy.uint8),  # pylint: disable=E1101
        (TensorProto.UINT16, numpy.uint16),  # pylint: disable=E1101
        (TensorProto.FLOAT16, numpy.float16),  # pylint: disable=E1101
    )
    for proto_dtype, np_dtype in known:
        if proto_type == proto_dtype:
            return np_dtype
    raise ValueError(
        "Unable to convert proto_type {} to numpy type.".format(
            proto_type))
def to_skl2onnx_type(name, elem_type, shape):
    """
    Converts *name*, *elem_type*, *shape* into a
    :epkg:`sklearn-onnx` type.

    :param name: string
    :param elem_type: tensor of elements of this type
    :param shape: expected shape
    :return: data type
    """
    from skl2onnx.common.data_types import _guess_numpy_type  # delayed
    elem = guess_numpy_type_from_string(elem_type)
    # Zero dimensions stand for unknown dimensions in skl2onnx.
    new_shape = [None if d == 0 else d for d in shape]
    return (name, _guess_numpy_type(elem, new_shape))
def from_pb(obj):
    """
    Extracts tensor description from a protobuf.

    :param obj: initializer, tensor (a repeated container is
        processed element by element)
    :return: (name, type, shape) where *type* is a numpy type and
        *shape* is a list of dimensions (integer, or None when
        symbolic or undefined)
    :raises NotImplementedError: if the object does not hold a
        supported tensor type
    """
    def get_dim(d):
        # Extracts one dimension as an integer, or None when the
        # dimension is symbolic or undefined.
        r = d.dim_value
        if "dim_param" in str(d):
            # the dimension is named (symbolic), not a fixed integer
            return None
        if r == 0:
            # dim_value is 0 when it is 0 or undefined
            # NOTE(review): the protobuf text representation is used to
            # distinguish an explicit 0 from an unset field — fragile,
            # confirm against the onnx version in use.
            return 0 if "0" in str(d) else None
        return r

    def get_shape(tt):
        # Converts a tensor type protobuf into a list of dimensions.
        return [get_dim(tt.shape.dim[i])
                for i in range(len(tt.shape.dim))]

    if hasattr(obj, 'extend'):
        # repeated field: processes every element
        return [from_pb(o) for o in obj]

    name = obj.name
    if obj.type.tensor_type:
        tt = obj.type.tensor_type
        elem = tt.elem_type
        shape = get_shape(tt)
        if elem not in TENSOR_TYPE_TO_NP_TYPE:
            raise NotImplementedError(
                "Unsupported type '{}' (elem_type={}).".format(
                    type(obj.type.tensor_type), elem))
        ty = TENSOR_TYPE_TO_NP_TYPE[elem].type
    else:
        raise NotImplementedError("Unsupported type '{}' as "
                                  "a string ({}).".format(
                                      type(obj), obj))

    return (name, ty, shape)
def numpy_type_prototype(dtype):
    """
    Converts a numpy dtype into a TensorProto dtype.

    :param dtype: dtype
    :return: proto dtype
    :raises ValueError: if the dtype has no proto equivalent
    """
    # Looks up the value as given, then as a normalized numpy dtype.
    for key in (dtype, numpy.dtype(dtype)):
        if key in NP_TYPE_TO_TENSOR_TYPE:
            return NP_TYPE_TO_TENSOR_TYPE[key]
    raise ValueError(  # pragma: no cover
        "Unable to convert dtype %r into ProtoType." % dtype)
def make_value_info(name, dtype, shape):
    """
    Converts a variable defined by its name, type and shape
    into `onnx.ValueInfoProto`.

    :param name: variable name
    :param dtype: numpy dtype
    :param shape: expected shape
    :return: instance of `onnx.ValueInfoProto`
    """
    info = ValueInfoProto()
    info.name = name
    tensor_type_proto = make_tensor_type_proto(
        numpy_type_prototype(dtype), shape)
    info.type.CopyFrom(tensor_type_proto)  # pylint: disable=E1101
    return info
# Cache for _get_onnx_function, filled on the first call.
_get_onnx_function_cache = None


def _get_onnx_function():
    """
    Returns the list of functions defined in ONNX package,
    indexed by ``(domain, name)``. The result is cached.
    """
    global _get_onnx_function_cache  # pylint: disable=W0603
    if _get_onnx_function_cache is None:
        _get_onnx_function_cache = {}
        for fct in get_function_ops():
            key = fct.domain, fct.name
            if key in _get_onnx_function_cache:
                raise RuntimeError(  # pragma: no cover
                    "Function %r is already registered." % (key, ))
            _get_onnx_function_cache[key] = fct
    return _get_onnx_function_cache
def get_onnx_schema(opname, domain='', opset=None, load_function=False):
    """
    Returns the operator schema for a specific operator.

    :param opname: operator name
    :param domain: operator domain
    :param opset: opset or version, None for the latest
    :param load_function: loads the function, if True, the function
        looks into the list of function if one of them has the same name,
        opset must be None in that case
    :return: :epkg:`OpSchema`
    """
    if load_function:
        if opset is not None:
            raise ValueError(
                "opset must be None if load_function is True for "
                "operator (%r,%r)." % (domain, opname))
        fcts = _get_onnx_function()
        key = domain, opname
        if key in fcts:
            return fcts[key]
    # Shared fall-through: the original duplicated these two lines in
    # both branches.
    if opset is None:
        opset = onnx_opset_version()
    return get_schema(opname, opset, domain)