Coverage for onnxcustom/utils/onnx_helper.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# pylint: disable=C0415,E0611,E1101
2"""
3@file
4@brief Onnx implementation of common functions used to train a model.
5"""
6import math
7import numpy
8from onnx import TensorProto, numpy_helper, helper
9from onnxruntime import OrtValue
10from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue
def onnx_rename_weights(onx):
    """
    Renames ONNX initializers to make sure their name
    follows the alphabetical order. The model is
    modified inplace. This function calls
    :func:`onnx_rename_names
    <mlprodict.onnx_tools.onnx_manipulations.onnx_rename_names>`.

    Every initializer ``name`` becomes ``I<index>_<name>`` where
    ``<index>`` is its zero-padded position in ``onx.graph.initializer``.

    :param onx: ONNX model
    :return: same model

    .. note::
        The function does not go into subgraphs.
    """
    from mlprodict.onnx_tools.onnx_manipulations import (  # pylint: disable=C0415
        onnx_rename_names)

    names = [init.name for init in onx.graph.initializer]
    # Zero padding keeps the lexicographic order of the new names equal
    # to the position order. math.log(0) raises a math domain error, so a
    # model without any initializer gets the minimal width.
    if names:
        width = max(1, int(math.log(len(names)) / math.log(10) + 1))
    else:
        width = 1
    fmt = "I%0{}d_%s".format(width)
    repl = {name: fmt % (i, name) for i, name in enumerate(names)}
    return onnx_rename_names(onx, recursive=False, replace=repl)
def get_onnx_opset(onx, domain=''):
    """
    Returns the opset version a model imports for a given domain.

    :param onx: ONNX model (any object exposing ``opset_import``)
    :param domain: operator domain, ``''`` is the default ONNX domain
        and ``'ai.onnx'`` is accepted as another spelling of it
    :return: opset version
    :raises ValueError: if the model does not import *domain*
    """
    for opset in onx.opset_import:
        if opset.domain == domain:
            return opset.version
    # The ONNX specification considers '' and 'ai.onnx' as the same
    # default domain; a model may declare either spelling.
    default_domains = ('', 'ai.onnx')
    if domain in default_domains:
        for opset in onx.opset_import:
            if opset.domain in default_domains:
                return opset.version
    raise ValueError(
        "Unable to find opset for domain=%r." % domain)
def proto_type_to_dtype(proto_type):
    """
    Maps an ONNX element type onto the matching numpy scalar type.

    Both the ``TensorProto`` integer constant and the string form
    (``'tensor(float)'``, ``'tensor(double)'``) are recognized.
    Only float and double are supported.

    :param proto_type: ``TensorProto`` constant or type string
    :return: ``numpy.float32`` or ``numpy.float64``
    :raises ValueError: for any other value
    """
    known_types = (
        (TensorProto.FLOAT, numpy.float32),
        (TensorProto.DOUBLE, numpy.float64),
        # String comparisons are slower, they are tried last.
        ('tensor(float)', numpy.float32),
        ('tensor(double)', numpy.float64),
    )
    for candidate, np_dtype in known_types:
        if proto_type == candidate:
            return np_dtype
    raise ValueError(
        "Unexpected value proto_type=%r (type=%r)." % (
            proto_type, type(proto_type)))
def dtype_to_var_type(dtype):
    """
    Returns the skl2onnx tensor type class matching a numpy dtype.

    :param dtype: numpy type or dtype (float32, float64, int64, int32)
    :return: the class ``FloatTensorType``, ``DoubleTensorType``,
        ``Int64TensorType`` or ``Int32TensorType`` (not an instance)
    :raises ValueError: for any other dtype
    """
    from skl2onnx.common.data_types import (
        FloatTensorType, DoubleTensorType,
        Int32TensorType, Int64TensorType)
    mapping = [
        (numpy.float32, FloatTensorType),
        (numpy.float64, DoubleTensorType),
        (numpy.int64, Int64TensorType),
        (numpy.int32, Int32TensorType),
    ]
    for np_type, var_type in mapping:
        if dtype == np_type:
            return var_type
    raise ValueError(
        "Unexpected value dtype=%r." % dtype)
def _finalize_new_onnx(graph, onx):
    """
    Wraps *graph* into a new model carrying the model-level
    attributes of *onx*: ir_version, producer, domain, model_version,
    doc_string, metadata, opsets and model-local functions.

    :param graph: GraphProto of the new model
    :param onx: original ModelProto the attributes are copied from
    :return: new ModelProto
    """
    onnx_model = helper.make_model(graph)
    onnx_model.ir_version = onx.ir_version
    onnx_model.producer_name = onx.producer_name
    onnx_model.producer_version = onx.producer_version
    onnx_model.domain = onx.domain
    onnx_model.model_version = onx.model_version
    onnx_model.doc_string = onx.doc_string
    if len(onx.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in onx.metadata_props}
        helper.set_model_props(onnx_model, values)

    # make_model may fill opset_import on its own, replace it with
    # exactly the opsets declared by the original model.
    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in onx.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    # Model-local functions (ModelProto.functions, recent onnx versions)
    # may be called by nodes of the graph, they must be kept as well.
    functions = getattr(onx, 'functions', None)
    if functions:
        onnx_model.functions.extend(functions)  # pylint: disable=E1101
    return onnx_model
def add_initializer(model, name, value):
    """
    Adds an initializer to graph.

    A new model is built and returned, the graph inputs, outputs,
    nodes, value_info and doc_string are preserved.

    :param model: onnx model
    :param name: initializer name, must not be used by another initializer
    :param value: value, a numpy array or an OrtValue
    :return: new ONNX model
    :raises ValueError: if *name* is already an initializer name
    """
    inits = set(i.name for i in model.graph.initializer)
    if name in inits:
        raise ValueError(  # pragma: no cover
            "Name %r is already taken among %r." % (
                name, inits))
    # Same conversion as replace_initializers_into_onnx.
    if isinstance(value, (C_OrtValue, OrtValue)):
        value = value.numpy()
    list_inits = list(model.graph.initializer)
    list_inits.append(
        numpy_helper.from_array(value, name=name))
    graph_def = helper.make_graph(
        model.graph.node, model.graph.name,
        model.graph.input, model.graph.output,
        list_inits, doc_string=model.graph.doc_string,
        value_info=model.graph.value_info)
    return _finalize_new_onnx(graph_def, model)
def replace_initializers_into_onnx(model, results):
    """
    Replaces initializers by other initializers,
    usually trained ones.

    A new model is built and returned, the graph inputs, outputs,
    nodes, value_info and doc_string are preserved.

    :param model: onnx model
    :param results: dictionary ``{initializer name: new value}``, a value
        is a numpy array, an OrtValue or already a TensorProto
    :return: new onnx model
    :raises RuntimeError: if a key of *results* is not an initializer
    """
    inputs = list(model.graph.input)
    outputs = list(model.graph.output)
    inits = list(model.graph.initializer)

    inits_dict = {init.name: i for i, init in enumerate(inits)}
    for k, v in results.items():
        if k not in inits_dict:
            raise RuntimeError(  # pragma: no cover
                "Unable to find initializer %r in "
                "%r." % (k, inits_dict))
        if isinstance(v, numpy.ndarray):
            v = numpy_helper.from_array(v, k)
        elif isinstance(v, (C_OrtValue, OrtValue)):
            v = numpy_helper.from_array(v.numpy(), k)
        # Any other value is assumed to be a TensorProto already.
        inits[inits_dict[k]] = v

    graph = helper.make_graph(
        list(model.graph.node), model.graph.name, inputs,
        outputs, inits, doc_string=model.graph.doc_string,
        value_info=model.graph.value_info)
    return _finalize_new_onnx(graph, model)