Coverage for mlprodict/onnx_tools/exports/numpy_helper.py: 88%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Numpy helpers for the conversion from onnx to numpy.
4"""
5import numpy
def make_slice(data, starts, ends, axes=None, steps=None):
    """
    Implements operator slice in numpy.

    :param data: input array
    :param starts: mandatory, first index of the slice
        for every sliced axis
    :param ends: mandatory, end index (excluded) of the slice
        for every sliced axis
    :param axes: optional, axes the slices apply to,
        the first ``len(starts)`` axes if not specified
    :param steps: optional, step of the slice for every sliced axis
    :return: results
    """
    # By default every axis is taken entirely.
    slices = [slice(0, dim) for dim in data.shape]
    if axes is None:
        axes = range(len(starts))
    for i, a in enumerate(axes):
        if steps is None:
            slices[a] = slice(starts[i], ends[i])
        else:
            slices[a] = slice(starts[i], ends[i], steps[i])
    # numpy expects a tuple for a multidimensional index,
    # a list of slices is not interpreted as one slice per axis.
    return data[tuple(slices)]
def argmax_use_numpy_select_last_index(
        data, axis=0, keepdims=True, select_last_index=False):
    """
    Needed for operator `ArgMax`.

    :param data: input array
    :param axis: axis along which the maximum is searched
    :param keepdims: keeps the reduced axis in the result
    :param select_last_index: if the maximum appears more than once,
        returns the position of its last occurrence instead of
        the first one
    :return: array of positions (int64)
    """
    if select_last_index:
        # The first maximum of the flipped array is the last one
        # of the original array, the position is then mirrored.
        flipped = numpy.flip(data, axis)
        length = flipped.shape[axis]
        positions = length - numpy.argmax(flipped, axis=axis) - 1
        if keepdims:
            positions = numpy.expand_dims(positions, axis)
        return positions.astype(numpy.int64)

    positions = numpy.argmax(data, axis=axis)
    if keepdims and len(positions.shape) < len(data.shape):
        positions = numpy.expand_dims(positions, axis)
    return positions.astype(numpy.int64)
def argmin_use_numpy_select_last_index(
        data, axis=0, keepdims=True, select_last_index=False):
    """
    Needed for operator `ArgMin`.

    :param data: input array
    :param axis: axis along which the minimum is searched
    :param keepdims: keeps the reduced axis in the result
    :param select_last_index: if the minimum appears more than once,
        returns the position of its last occurrence instead of
        the first one
    :return: array of positions (int64)
    """
    if select_last_index:
        # The first minimum of the flipped array is the last one
        # of the original array, the position is then mirrored.
        flipped = numpy.flip(data, axis)
        length = flipped.shape[axis]
        positions = length - numpy.argmin(flipped, axis=axis) - 1
        if keepdims:
            positions = numpy.expand_dims(positions, axis)
        return positions.astype(numpy.int64)

    positions = numpy.argmin(data, axis=axis)
    if keepdims and len(positions.shape) < len(data.shape):
        positions = numpy.expand_dims(positions, axis)
    return positions.astype(numpy.int64)
def array_feature_extrator(data, indices):
    """
    Implementation of operator *ArrayFeatureExtractor*
    with :epkg:`numpy`.

    :param data: array the features are extracted from
    :param indices: array of positions taken along the last axis
        of *data*, whatever its shape, it is used flattened
    :return: extracted features, the last dimension is the number
        of indices, a one dimension *data* produces a single row
    """
    positions = indices.ravel().tolist()
    count = len(positions)
    if len(data.shape) == 1:
        target_shape = (1, count)
    else:
        target_shape = [*data.shape[:-1], count]
    return data[..., positions].reshape(target_shape)
class NumpyCode:
    """
    Converts an ONNX operators into :epkg:`numpy` code.

    :param opset: target opset for the conversion (usually unused)
    :param name: node name
    :param op_type: operator type
    :param domain: domain
    :param inputs: inputs
    :param outputs: outputs
    :param attributes: attributes, iterable of pairs `(name, value)`
    :param used: dictionary `{k: v}`,
        list of nodes taking *k* as input
    :param context: whole context
    :param mark_inits: marks initializer as replaced
    :param indent: indentation of the second line and following

    The code is returned as a string by method
    :meth:`make_numpy_code`.
    """

    def __init__(self, opset, name=None, op_type=None, domain='',
                 inputs=None, outputs=None, attributes=None,
                 used=None, context=None, mark_inits=None,
                 indent="", **unused):
        self.opset = opset
        self.name = name
        self.op_type = op_type
        self.domain = domain
        self.inputs = inputs
        self.outputs = outputs
        self.attributes = attributes
        self.used = used
        self.context = context
        self.mark_inits = mark_inits
        self.unused = unused
        self.indent = indent

    def _make_sure_inputs(self, n, m=None):
        """
        Checks the operator has between *n* and *m* inputs
        (exactly *n* if *m* is None), raises an exception otherwise.
        """
        if m is None:
            m = n
        if len(self.inputs) < n:
            raise RuntimeError(  # pragma: no cover
                "Expecting at least %d inputs for operator %r not %r." % (
                    n, self.op_type, self.inputs))
        if len(self.inputs) > m:
            raise RuntimeError(  # pragma: no cover
                "Expecting at most %d inputs for operator %r not %r." % (
                    m, self.op_type, self.inputs))

    def _make_sure_opsets(self, mi, ma=None):
        """
        Checks the target opset is in `[mi, ma]`
        (a bound equal to None is ignored),
        raises an exception otherwise.
        """
        # op_type is a string: it is formatted with %r, not %d.
        if mi is not None and self.opset < mi:
            raise RuntimeError(  # pragma: no cover
                "Cannot convert operator type %r, opset %d < %d." % (
                    self.op_type, self.opset, mi))
        if ma is not None and self.opset > ma:
            raise RuntimeError(  # pragma: no cover
                "Cannot convert operator type %r, opset %d > %d." % (
                    self.op_type, self.opset, ma))

    def _getat(self, name, defval=None, format=None):
        """
        Returns the value of attribute *name* or *defval* if the
        operator does not define it.

        :param name: attribute name
        :param defval: value returned if the attribute is missing
        :param format: None to return the value as is,
            `'listint'` or `'listfloat'` to parse a string
            such as `'[1, 2]'` into a list
        :return: attribute value
        """

        def f(v):
            if format is None:
                return v
            if format == 'listint' and isinstance(v, str):
                return list(
                    map(int, v.strip('[]').replace(' ', '').split(',')))
            if format == 'listfloat' and isinstance(v, str):
                return list(
                    map(float, v.strip('[]').replace(' ', '').split(',')))
            raise ValueError(  # pragma: no cover
                "Unable to convert %r with format=%r." % (v, format))

        # attributes is None when the node has no attribute.
        for n, val in self.attributes or []:
            if name == n:
                return f(val)
        return defval

    def _simplify(self, name, kind):
        """
        Returns the expression to use for input *name*. If *name*
        is a small int64 initializer used by a single node, its
        value is written in the code and the initializer is
        registered in *mark_inits*, otherwise the expression
        refers to the variable.

        :param name: input name
        :param kind: `'tuple'` or `'list'`
        :return: expression as a string
        """
        value = None
        if (self.used is not None and name in self.used and
                len(self.used[name]) == 1 and self.context is not None):
            inits = self.context['initializers_dict']
            if name in inits:
                v = inits[name]
                if v.dtype == numpy.int64 and v.size < 10:
                    value = v
                    if self.mark_inits is not None:
                        if name not in self.mark_inits:
                            self.mark_inits[name] = []
                        self.mark_inits[name].append(v)

        # The values are converted into python integers with tolist:
        # the representation of a numpy scalar inside a container
        # depends on the numpy version and is not always valid code.
        if kind == 'tuple':
            if value is None:
                return "tuple(%s)" % name
            if value.size == 1:
                return str(value.ravel().tolist()[0])
            return str(tuple(value.tolist()))
        elif kind == 'list':
            if value is None:
                return name
            if len(value.shape) == 0:
                return str(value)
            return str(value.tolist())
        raise NotImplementedError(  # pragma: no cover
            "Unknown scenario to simplify (%r)." % kind)

    @staticmethod
    def _make_tuple(val):
        """
        Converts *val* (tuple, list, string such as `'[1, 0]'`)
        into a tuple, None and integers are returned as is.
        """
        if val is None:
            # missing attribute, left to the default of the numpy function
            return None
        if isinstance(val, tuple):
            return val
        if isinstance(val, list):
            return tuple(val)
        if isinstance(val, int):
            return val
        if isinstance(val, str):
            return tuple(map(int, val.strip('()[]').replace(" ", "").split(",")))
        raise NotImplementedError(  # pragma: no cover
            "Unable to convert %r into tuple." % val)

    def make_numpy_code(self):
        """
        Main method, returns the python code for a given
        operator.
        """
        if self.domain == '':
            return self._make_numpy_code_onnx()

        if self.domain == 'ai.onnx.ml':
            return self._make_numpy_code_onnxml()

        if self.domain == 'com.microsoft':
            return self._make_numpy_code_others()

        raise NotImplementedError(  # pragma: no cover
            "Unable to convert any operator from domain %r." % self.domain)

    def _make_numpy_code_onnx(self):
        """
        Returns the code for an operator of the main domain (`''`).
        """
        binary_ops = dict(Add='+', Sub='-', Div='/', Mul='*', MatMul='@',
                          Pow='**')
        unary_ops = dict(Neg='-')
        unary_ops_ = dict(Sqrt='** 0.5')

        outs = ", ".join(self.outputs)

        if self.op_type in binary_ops:
            self._make_sure_inputs(2)
            return "%s = %s %s %s" % (
                outs, self.inputs[0], binary_ops[self.op_type],
                self.inputs[1])

        if self.op_type in unary_ops:
            self._make_sure_inputs(1)
            return "%s = %s %s" % (
                outs, unary_ops[self.op_type], self.inputs[0])

        if self.op_type in unary_ops_:
            self._make_sure_inputs(1)
            return "%s = %s %s" % (
                outs, self.inputs[0], unary_ops_[self.op_type])

        if self.op_type in ('ArgMax', 'ArgMin'):
            # 'argmax' or 'argmin', both operators share the same logic
            fct = self.op_type.lower()
            self._make_sure_opsets(12)
            self._make_sure_inputs(1)
            axis = self._getat('axis', 0)
            keepdims = self._getat('keepdims', 1)
            select_last_index = self._getat('select_last_index', 0)
            if select_last_index:
                return (
                    "%s = %s_use_numpy_select_last_index("
                    "%s, axis=%s, keepdims=%s, select_last_index=%s)" % (
                        outs, fct, self.inputs[0], axis, keepdims,
                        select_last_index))
            if keepdims:
                # the reduced axis is restored where it was
                return "%s = numpy.expand_dims(numpy.%s(%s, axis=%s), %s)" % (
                    outs, fct, self.inputs[0], axis, axis)
            return "%s = numpy.%s(%s, axis=%s)" % (
                outs, fct, self.inputs[0], axis)

        if self.op_type == 'Cast':
            from ..onnx2py_helper import _elem_type_as_str
            self._make_sure_inputs(1)
            to = int(self._getat('to', 1))
            dtype = _elem_type_as_str(to)
            dtype = {'double': 'float64', 'float': 'float32'}.get(dtype, dtype)
            return "%s = %s.astype(numpy.%s)" % (outs, self.inputs[0], dtype)

        if self.op_type == 'Concat':
            axis = self._getat('axis', 0)
            return "%s = numpy.concatenate([%s], %s)" % (
                outs, ", ".join(self.inputs), axis)

        if self.op_type == 'ConstantOfShape':
            self._make_sure_opsets(9)
            self._make_sure_inputs(1)
            value = self._getat('value', 0, format='listfloat')
            shape = self._simplify(self.inputs[0], kind='tuple')
            return "%s = numpy.full(%s, %s)" % (
                outs, shape, value)

        if self.op_type == 'Exp':
            return "%s = numpy.exp(%s)" % (outs, self.inputs[0])

        if self.op_type == 'Max':
            # numpy.maximum takes two arrays, the calls are nested
            # to handle any number of inputs.
            self._make_sure_inputs(1, len(self.inputs))
            expr = self.inputs[0]
            for other in self.inputs[1:]:
                expr = "numpy.maximum(%s, %s)" % (expr, other)
            return "%s = %s" % (outs, expr)

        if self.op_type == 'Gather':
            self._make_sure_opsets(11)
            self._make_sure_inputs(2)
            axis = self._getat('axis', 0)
            return "%s = numpy.take(%s, %s, axis=%s)" % (
                outs, self.inputs[0],
                self._simplify(self.inputs[1], 'list'), axis)

        if self.op_type == 'Gemm':
            self._make_sure_inputs(2, 3)
            # ONNX default values for alpha and beta are 1.
            alpha = self._getat('alpha', 1.)
            transA = self._getat('transA', 0)
            transB = self._getat('transB', 0)
            ta = ".T" if transA in ('1', 1, True) else ""
            tb = ".T" if transB in ('1', 1, True) else ""
            if len(self.inputs) == 2:
                return "%s = %s%s @ %s%s * %s" % (
                    outs, self.inputs[0], ta, self.inputs[1], tb, alpha)
            beta = self._getat('beta', 1.)
            return "%s = %s%s @ %s%s * %s + %s * %s" % (
                outs, self.inputs[0], ta, self.inputs[1], tb, alpha,
                self.inputs[2], beta)

        if self.op_type == 'Identity':
            return "%s = %s" % (outs, self.inputs[0])

        if self.op_type == 'ReduceProd':
            self._make_sure_inputs(1)
            axes = self._getat('axes', "[0]")
            # ONNX default value for keepdims is 1.
            keepdims = self._getat('keepdims', 1)
            return "%s = %s.prod(axis=tuple(%s), keepdims=%s)" % (
                outs, self.inputs[0], axes, keepdims)

        if self.op_type == 'ReduceSum':
            self._make_sure_opsets(11)
            self._make_sure_inputs(2)
            keepdims = self._getat('keepdims', 1)
            return "%s = %s.sum(axis=%s, keepdims=%s)" % (
                outs, self.inputs[0], self._simplify(self.inputs[1], 'tuple'),
                keepdims)

        if self.op_type == 'ReduceSumSquare':
            self._make_sure_inputs(1)
            axes = self._getat('axes', "[0]")
            keepdims = self._getat('keepdims', 1)
            return "%s = (%s ** 2).sum(axis=tuple(%s), keepdims=%s)" % (
                outs, self.inputs[0], axes, keepdims)

        if self.op_type == 'Reshape':
            self._make_sure_inputs(2)
            simp = self._simplify(self.inputs[1], 'tuple')
            return "%s = %s.reshape(%s)" % (
                outs, self.inputs[0], simp)

        if self.op_type == 'Shape':
            self._make_sure_inputs(1)
            return "%s = numpy.array(%s.shape, dtype=numpy.int64)" % (
                outs, self.inputs[0])

        if self.op_type == 'Slice':
            return "%s = make_slice(%s)" % (outs, ", ".join(self.inputs))

        if self.op_type == 'Softmax':
            self._make_sure_inputs(1)
            axis = self._getat('axis', -1)
            return "%s = scipy_special.softmax(%s, axis=%s)" % (
                outs, self.inputs[0], axis)

        if self.op_type == 'Squeeze':
            self._make_sure_opsets(13)
            self._make_sure_inputs(2)
            return "%s = numpy.squeeze(%s, axis=%s)" % (
                outs, self.inputs[0], self._simplify(self.inputs[1], 'tuple'))

        if self.op_type == 'Transpose':
            self._make_sure_inputs(1)
            # perm is optional, axes=None reverses the axes
            perm = self._getat('perm', None)
            return "%s = numpy.transpose(%s, axes=%s)" % (
                outs, self.inputs[0], self._make_tuple(perm))

        if self.op_type == 'Unsqueeze':
            self._make_sure_opsets(13)
            self._make_sure_inputs(2)
            return "%s = numpy.expand_dims(%s, axis=%s)" % (
                outs, self.inputs[0],
                self._simplify(self.inputs[1], 'tuple'))

        raise NotImplementedError(  # pragma: no cover
            "Unable to convert operator type %r name=%r." % (
                self.op_type, self.name))

    def _make_numpy_code_onnxml(self):
        """
        Returns the code for an operator of domain `ai.onnx.ml`.
        The first line of the code is not indented,
        the following ones start with *indent*.
        """
        outs = ", ".join(self.outputs)

        if self.op_type == 'ArrayFeatureExtractor':
            self._make_sure_inputs(2)
            return "%s = array_feature_extrator(%s, %s)" % (
                outs, self.inputs[0], self.inputs[1])

        if self.op_type == 'LinearClassifier':
            # NOTE(review): the value is read from attribute 'targets',
            # the ONNX-ML attribute is presumably 'multi_class',
            # to be confirmed before changing it.
            multi_class = self._getat('targets', 0)
            if multi_class != 0:
                raise NotImplementedError(  # pragma: no cover
                    "Conversion of operator %r with multi_class=%r "
                    "is not implemented." % (self.op_type, multi_class))
            self._make_sure_inputs(1)
            coefficients = self._getat('coefficients', None)
            intercepts = self._getat('intercepts', None)
            post_transform = self._getat(
                'post_transform', 'NONE').strip('"\'b')
            classlabels_strings = self._getat('classlabels_strings', None)
            if classlabels_strings is not None:
                raise NotImplementedError(  # pragma: no cover
                    "Conversion of operator %r with classlabels_strings=%r "
                    "is not implemented." % (self.op_type, classlabels_strings))
            classlabels_ints = self._getat(
                'classlabels_ints', None, format="listint")
            if classlabels_ints is None:
                raise NotImplementedError(  # pragma: no cover
                    "Conversion of operator %r without classlabels_ints "
                    "is not implemented." % self.op_type)
            # Labels must be 0, 1, 2, ... as the predicted label
            # is the position of the highest score.
            if classlabels_ints != list(range(len(classlabels_ints))):
                raise NotImplementedError(  # pragma: no cover
                    "Conversion of operator %r with classlabels_ints=%r!=%r "
                    "is not implemented." % (
                        self.op_type, classlabels_ints,
                        list(range(len(classlabels_ints)))))
            targets = len(classlabels_ints)
            rows = [
                "coefs = numpy.array(%s, dtype=numpy.float32)."
                "reshape((%d, -1)).T" % (coefficients, targets),
                "%sinter = numpy.array(%s, dtype=numpy.float32)."
                "reshape((-1, %d))" % (self.indent, intercepts, targets)]

            if post_transform == "SOFTMAX":
                rows.append(
                    "%s%s = scipy_special.softmax"
                    "(%s @ coefs + inter, axis=1)" % (
                        self.indent, self.outputs[1], self.inputs[0]))
            elif post_transform == 'NONE':
                rows.append(
                    "%s%s = %s @ coefs + inter" % (
                        self.indent, self.outputs[1], self.inputs[0]))
            else:
                raise NotImplementedError(  # pragma: no cover
                    "Conversion of operator %r with post_transform=%r "
                    "is not implemented." % (self.op_type, post_transform))
            rows.append("%s%s = numpy.argmax(%s, axis=1)" % (
                self.indent, self.outputs[0], self.outputs[1]))
            return "\n".join(rows)

        if self.op_type == 'LinearRegressor':
            self._make_sure_inputs(1)
            coefficients = self._getat('coefficients', None)
            intercepts = self._getat('intercepts', None)
            post_transform = self._getat(
                'post_transform', 'NONE').strip('"\'b')
            targets = self._getat('targets', 1)
            if post_transform != "NONE":
                raise NotImplementedError(  # pragma: no cover
                    "Conversion of operator %r with post_transform=%r "
                    "is not implemented." % (self.op_type, post_transform))
            rows = [
                "coefs = numpy.array(%s, dtype=numpy.float32)."
                "reshape((%d, -1)).T" % (coefficients, targets),
                "%sinter = numpy.array(%s, dtype=numpy.float32)."
                "reshape((-1, %d))" % (self.indent, intercepts, targets),
                "%s%s = %s @ coefs + inter" % (
                    self.indent, outs, self.inputs[0])]
            return "\n".join(rows)

        if self.op_type == 'Normalizer':
            self._make_sure_inputs(1)
            norm = self._getat('norm', 'MAX').strip('"\'b')
            if norm == 'L2':
                # keepdims is needed to divide every row by its own norm
                return "%s = %s / (%s ** 2).sum(axis=1, keepdims=1) ** 0.5" % (
                    outs, self.inputs[0], self.inputs[0])
            if norm == 'L1':
                norm = 'sum'
            return "%s = %s / %s.%s(axis=1, keepdims=1)" % (
                outs, self.inputs[0], self.inputs[0], norm.lower())

        raise NotImplementedError(  # pragma: no cover
            "Unable to convert operator type %r name=%r (onnxml)." % (
                self.op_type, self.name))

    def _make_numpy_code_others(self):
        """
        Returns the code for an operator of domain `com.microsoft`.
        """
        outs = ", ".join(self.outputs)

        if self.op_type == 'CDist':
            self._make_sure_inputs(2)
            # NOTE(review): strip("'b") also removes a leading or
            # trailing letter b from the metric name itself
            # (e.g. 'braycurtis'), to be confirmed.
            metric = self._getat('metric', 'euclidean').strip("'b")
            return "%s = scipy_distance.cdist(%s, %s, metric=%r)" % (
                outs, self.inputs[0], self.inputs[1], metric)

        raise NotImplementedError(  # pragma: no cover
            "Unable to convert operator type %r (domain=%r) "
            "name=%r (onnxml)." % (
                self.op_type, self.domain, self.name))
def make_numpy_code(opset, name=None, op_type=None, domain='',
                    inputs=None, outputs=None, attributes=None,
                    used=None, context=None, mark_inits=None,
                    indent="", **unused):
    """
    Converts an ONNX operators into :epkg:`numpy` code.
    The function builds an instance of class `NumpyCode`
    with the same parameters and returns the result of
    its method `make_numpy_code`.

    :param opset: target opset for the conversion (usually unused)
    :param name: node name
    :param op_type: operator type
    :param domain: domain
    :param inputs: inputs
    :param outputs: outputs
    :param attributes: attributes
    :param used: dictionary `{k: v}`,
        list of nodes taking *k* as input
    :param context: whole context
    :param mark_inits: marks initializer as replaced
    :param indent: indentation of the second line and following
    :param unused: additional parameters forwarded to the class
    :return: code as str
    """
    settings = dict(
        opset=opset, name=name, op_type=op_type, domain=domain,
        inputs=inputs, outputs=outputs, attributes=attributes,
        used=used, context=context, mark_inits=mark_inits,
        indent=indent)
    settings.update(unused)
    converter = NumpyCode(**settings)
    return converter.make_numpy_code()