Coverage for mlprodict/sklapi/onnx_transformer.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# coding: utf-8
2"""
3@file
4@brief Wraps runtime into a :epkg:`scikit-learn` transformer.
5"""
6from io import BytesIO
7import numpy
8import pandas
9import onnx
10from sklearn.base import BaseEstimator, TransformerMixin
11from skl2onnx.algebra.onnx_operator_mixin import OnnxOperatorMixin
12from mlprodict.onnx_tools.onnx_manipulations import (
13 select_model_inputs_outputs, enumerate_model_node_outputs)
14from ..onnx_tools.onnx2py_helper import _var_as_dict, onnx_model_opsets
15from ..onnx_tools.exports.skl2onnx_helper import add_onnx_graph
16from ..onnxrt import OnnxInference
class OnnxTransformer(BaseEstimator, TransformerMixin, OnnxOperatorMixin):
    """
    Calls :epkg:`onnxruntime` or the runtime implemented
    in this package to transform input based on a ONNX graph.
    It follows :epkg:`scikit-learn` API
    so that it can be included in a :epkg:`scikit-learn` pipeline.
    See notebook :ref:`transferlearningrst` for an example.

    :param onnx_bytes: bytes, or any object with a *SerializeToString*
        method (such as an ONNX model), which is serialized on construction
    :param output_name: string
        requested output name or None to request all and
        have method *transform* to store all of them in a dataframe
    :param enforce_float32: boolean
        :epkg:`onnxruntime` only supports *float32*,
        :epkg:`scikit-learn` usually uses double floats, this parameter
        ensures that every array of double floats is converted into
        single floats
    :param runtime: string, defined the runtime to use
        as described in @see cl OnnxInference.
    :param change_batch_size: some models are converted for
        a specific batch size, this parameter changes it,
        None to avoid changing it, 0 to fix an undefined
        first dimension
    :param reshape: reshape the output to get
        a matrix and not a multidimensional array
    """

    def __init__(self, onnx_bytes, output_name=None, enforce_float32=True,
                 runtime='python', change_batch_size=None, reshape=False):
        BaseEstimator.__init__(self)
        TransformerMixin.__init__(self)
        # Models are stored serialized so that the estimator only holds bytes.
        self.onnx_bytes = (onnx_bytes
                           if not hasattr(onnx_bytes, 'SerializeToString')
                           else onnx_bytes.SerializeToString())
        self.output_name = output_name
        self.enforce_float32 = enforce_float32
        self.runtime = runtime
        self.change_batch_size = change_batch_size
        self.reshape = reshape

    def __repr__(self):  # pylint: disable=W0222
        """
        usual
        """
        ob = self.onnx_bytes
        # Long byte strings are abbreviated to their first and last 10 bytes.
        if len(ob) > 20:
            ob = ob[:10] + b"..." + ob[-10:]
        return ("{0}(onnx_bytes={1}, output_name={2}, enforce_float32={3}, "
                "runtime='{4}')".format(
                    self.__class__.__name__, ob, self.output_name,
                    self.enforce_float32, self.runtime))

    def fit(self, X=None, y=None, **fit_params):
        """
        Loads the :epkg:`ONNX` model. If needed, it first restricts the
        model to the requested output and changes the first input
        dimension. It then builds the runtime.

        :param X: unused
        :param y: unused
        :param fit_params: additional parameter (unused)
        :return: self
        """
        from ..onnx_tools.optim.onnx_helper import change_input_first_dimension
        onx = onnx.load(BytesIO(self.onnx_bytes))
        self.op_version = onnx_model_opsets(onx)

        output_names = set(
            o.name for o in onx.graph.output)  # pylint: disable=E1101
        updated = False
        if (self.output_name is not None and
                self.output_name not in output_names):
            # The model refers to intermediate outputs.
            onx = select_model_inputs_outputs(
                onx, outputs=[self.output_name])
            updated = True

        if self.change_batch_size is not None:
            onx = change_input_first_dimension(
                onx, self.change_batch_size)
            updated = True

        # Reserialize only when the graph was modified.
        onnx_bytes = (
            onx.SerializeToString() if updated else self.onnx_bytes)
        self.onnxrt_ = OnnxInference(
            onnx_bytes, runtime=self.runtime,
            runtime_options=dict(log_severity_level=3))
        self.inputs_ = self.onnxrt_.input_names
        self.inputs_shape_types_ = self.onnxrt_.input_names_shapes_types
        return self

    def _check_arrays(self, inputs):
        """
        Ensures that double floats are converted into single floats
        if *enforce_float32* is True or raises an exception.

        *inputs* is modified in place: float64 arrays are replaced
        by float32 copies when *enforce_float32* is True.
        Once the model is fitted, the remaining numpy arrays are checked
        against the expected shapes (except the first dimension) and types.
        Values which are not numpy arrays are not checked.
        """
        has = hasattr(self, "onnxrt_")
        sht = self.inputs_shape_types_ if has else None
        if sht is not None and len(sht) < len(inputs):
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of inputs {} > {} (expected).".format(
                    len(inputs), len(sht)))
        for i, k in enumerate(inputs):
            v = inputs[k]
            if isinstance(v, numpy.ndarray):
                if v.dtype == numpy.float64 and self.enforce_float32:
                    inputs[k] = v.astype(numpy.float32)
                    continue
                if not has:
                    continue
                exp = sht[i]
                # exp is (name, shape, type); the first dimension is
                # not compared and ('?', ) means the shape is unknown.
                if exp[1] != ('?', ) and exp[1][1:] != v.shape[1:]:
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected shape for input '{}': {} != {} "
                        "(expected).".format(
                            k, v.shape, exp[1]))
                if ((v.dtype == numpy.float32 and exp[2] != 'tensor(float)') or
                        (v.dtype == numpy.float64 and exp[2] != 'tensor(double)')):
                    raise TypeError(  # pragma: no cover
                        "Unexpected dtype for input '{}': {} != {} "
                        "(expected).".format(
                            k, v.dtype, exp[2]))

    def transform(self, X, y=None, **inputs):
        """
        Runs the predictions. If *X* is a dataframe,
        the function assumes every columns is a separate input,
        otherwise, *X* is considered as a first input and *inputs*
        can be used to specify extra inputs.

        :param X: iterable, data to process
            (or first input if several expected)
        :param y: unused
        :param inputs: :epkg:`ONNX` graph support multiple inputs,
            each column of a dataframe is converted into as many inputs if
            *X* is a dataframe, otherwise, *X* is considered as the first input
            and *inputs* can be used to specify the other ones
        :return: the output itself (a :epkg:`DataFrame` if it is a list)
            when *output_name* is set or the model has a single output,
            otherwise a :epkg:`DataFrame` concatenating all outputs
        :raises AttributeError: if the transformer was not fitted
        """
        if not hasattr(self, "onnxrt_"):
            raise AttributeError(  # pragma: no cover
                "Transform OnnxTransformer must be fit first.")
        rt_inputs = {}
        if isinstance(X, numpy.ndarray):
            rt_inputs[self.inputs_[0]] = X
        elif isinstance(X, pandas.DataFrame):
            for c in X.columns:
                rt_inputs[c] = X[c]
        elif isinstance(X, dict) and len(inputs) == 0:
            # A dictionary X is only used when no extra inputs are given.
            rt_inputs.update(X)
        elif isinstance(X, list):
            if len(self.inputs_) == 1:
                rt_inputs[self.inputs_[0]] = numpy.array(X)
            else:
                # Each row holds one value per model input.
                for i, name in enumerate(self.inputs_):
                    rt_inputs[name] = [row[i] for row in X]

        # Explicit extra inputs override anything extracted from X.
        rt_inputs.update(inputs)

        names = ([self.output_name]
                 if self.output_name else self.onnxrt_.output_names)
        self._check_arrays(rt_inputs)
        doutputs = self.onnxrt_.run(rt_inputs)
        outputs = [doutputs[n] for n in names]

        if self.reshape:
            n = outputs[0].shape[0]
            outputs = [o.reshape((n, -1)) for o in outputs]

        if self.output_name or len(outputs) == 1:
            if isinstance(outputs[0], list):
                return pandas.DataFrame(outputs[0])
            return outputs[0]

        # From here, self.output_name is not set, so *names* holds every
        # output name of the runtime, in the same order as *outputs*.
        concat = []
        colnames = []
        for k, v in zip(names, outputs):
            if isinstance(v, numpy.ndarray):
                if len(v.shape) == 1:
                    v = v.reshape((-1, 1))
                    colnames.append(k)
                elif len(v.shape) == 2:
                    colnames.extend("%s%d" % (k, i) for i in range(v.shape[1]))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected shape for results %r: %r." % (k, v.shape))
            if isinstance(v, list):
                # A list output must be a list of dictionaries; its columns
                # are sorted to get a deterministic column order.
                if len(v) == 0:
                    raise RuntimeError(  # pragma: no cover
                        "Output %r is empty." % k)
                if not isinstance(v[0], dict):
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected type for output %r - value=%r."
                        "" % (k, v[0]))
                df = pandas.DataFrame(v)
                cols = list(sorted(df.columns))
                v = df[cols].copy().values
                colnames.extend("%s%d" % (k, i) for i in range(v.shape[1]))
            concat.append(v)
        res = numpy.hstack(concat)
        return pandas.DataFrame(res, columns=colnames)

    def fit_transform(self, X, y=None, **inputs):
        """
        Loads the *ONNX* model and runs the predictions.

        :param X: iterable, data to process
            (or first input if several expected)
        :param y: unused
        :param inputs: :epkg:`ONNX` graph support multiple inputs,
            each column of a dataframe is converted into as many inputs if
            *X* is a dataframe, otherwise, *X* is considered as the first input
            and *inputs* can be used to specify the other ones
        :return: same as method *transform*
        """
        # The extra inputs must reach *transform*, which feeds them to the
        # runtime (*fit* ignores them).
        return self.fit(X, y=y, **inputs).transform(X, y, **inputs)

    @staticmethod
    def enumerate_create(onnx_bytes, output_names=None, enforce_float32=True):
        """
        Creates multiple *OnnxTransformer*,
        one for each requested intermediate node.

        :param onnx_bytes: bytes, or any object with a *SerializeToString*
            method (such as an ONNX model)
        :param output_names: requested output names (a list, or a single
            string for one output) or None to request all of them
        :param enforce_float32: boolean
            :epkg:`onnxruntime` only supports *float32*,
            :epkg:`scikit-learn` usually uses double floats, this parameter
            ensures that every array of double floats is converted into
            single floats
        :return: iterator on OnnxTransformer *('output name', OnnxTransformer)*
        """
        if hasattr(onnx_bytes, 'SerializeToString'):
            onnx_bytes = onnx_bytes.SerializeToString()
        if isinstance(output_names, str):
            # A single name must not be split into its characters.
            output_names = [output_names]
        selected = None if output_names is None else set(output_names)
        model = onnx.load(BytesIO(onnx_bytes))
        for out in enumerate_model_node_outputs(model):
            if selected is not None and out not in selected:
                continue
            # The model is only truncated for the outputs actually requested.
            m = select_model_inputs_outputs(model, out)
            tr = OnnxTransformer(m.SerializeToString(),
                                 enforce_float32=enforce_float32)
            yield out, tr

    def onnx_parser(self):
        """
        Returns a parser for this model.

        The parser checks that the transformer is fitted and receives as many
        inputs as the model expects, then returns the runtime output names.
        """
        def parser(scope=None, inputs=None):
            if scope is None:
                raise RuntimeError(  # pragma: no cover
                    "scope cannot be None (parser of class %r)."
                    "" % type(self))
            if inputs is None:
                raise RuntimeError(  # pragma: no cover
                    "inputs cannot be None (parser of class %r)."
                    "" % type(self))
            if (not hasattr(self, 'onnxrt_') or
                    not hasattr(self.onnxrt_, 'output_names')):
                raise RuntimeError(  # pragma: no cover
                    'OnnxTransformer not fit.')
            if len(inputs) != len(self.inputs_):
                raise RuntimeError(  # pragma: no cover
                    "Mismatch between the number of inputs, expected %r, "
                    "got %r." % (self.inputs_, inputs))
            return self.onnxrt_.output_names
        return parser

    def onnx_shape_calculator(self):
        """
        Returns a shape calculator for this model.

        The calculator sets the type of every output of the operator.
        It uses the type of the matching output in the graph held by
        the fitted runtime. Only tensors of float, double or int64
        are supported.
        """
        def shape_calculator(operator):
            from skl2onnx.common.data_types import (  # delayed
                FloatTensorType, DoubleTensorType, Int64TensorType)
            types = {'float': FloatTensorType,
                     'int64': Int64TensorType,
                     'double': DoubleTensorType}
            cout = self.onnxrt_.output_names
            if len(operator.outputs) != len(cout):
                raise RuntimeError(  # pragma: no cover
                    "Mismatched number of outputs: {} != {}."
                    "".format(len(operator.outputs), len(cout)))
            for out_op, out in zip(operator.outputs,
                                   self.onnxrt_.obj.graph.output):
                var = _var_as_dict(out)
                if var['type']['kind'] != 'tensor':
                    raise NotImplementedError(  # pragma: no cover
                        "Not yet implemented for output:\n{}".format(out))
                shape = var['type']['shape']
                # A first dimension stored as 0 becomes None (unknown);
                # scalar outputs (empty shape) are kept as they are.
                if shape and shape[0] == 0:
                    shape = (None,) + tuple(shape[1:])
                elem = var['type']['elem']
                if elem not in types:
                    raise NotImplementedError(  # pragma: no cover
                        "Not yet implemented for elem_type: %r" % (elem, ))
                out_op.type = types[elem](shape=shape)
        return shape_calculator

    def onnx_converter(self):
        """
        Returns a converter for this model.
        The converter inserts the ONNX graph of the fitted runtime
        (or *onnx_model* if given) into the graph being built.
        """
        def converter(scope, operator, container, onnx_model=None):
            op = operator.raw_operator
            onx = onnx_model or op.onnxrt_.obj
            add_onnx_graph(scope, operator, container, onx)

        return converter

    @property
    def opsets(self):
        """
        Returns the opsets as dictionary ``{domain: opset}``.
        The fitted model is used if available, otherwise *onnx_bytes*
        is loaded.
        """
        if hasattr(self, 'onnxrt_'):
            model = self.onnxrt_.obj
        else:
            model = onnx.load(BytesIO(self.onnx_bytes))
        return {oimp.domain: oimp.version for oimp in model.opset_import}