Coverage for mlprodict/onnxrt/ops_cpu/_op.py: 98%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Shortcut to *ops_cpu*.
5"""
6import pprint
7import numpy
8import onnx
9import onnx.defs
10from ..shape_object import ShapeObject
11from ..type_object import SequenceType
12from ._new_ops import OperatorSchema
def _build_schemas():
    """
    Collects every operator schema known by :epkg:`onnx`.

    For every operator name, the schema with the highest
    ``since_version`` is stored under the plain name.
    Every schema is also stored under ``<name>_<since_version>``
    so that one specific version can be retrieved.

    @return     dictionary ``{name: schema}``
    """
    schemas = {}
    for sch in onnx.defs.get_all_schemas_with_history():
        name = sch.name
        current = schemas.get(name, None)
        if current is None or sch.since_version > current.since_version:
            # Several versions coexist, the most recent one
            # is the one stored under the plain name.
            schemas[name] = sch
        schemas["{}_{}".format(name, sch.since_version)] = sch
    return schemas
# Every operator schema known by onnx: the latest version under the plain
# name, each version also under '<name>_<since_version>'.
_schemas = _build_schemas()
# Operators for which a single expected attribute is enough
# (checked in OpRun.__init__).
_at_least_one = set(['Constant'])
class RuntimeTypeError(RuntimeError):
    """
    Exception raised by the runtime when a variable
    does not have the expected type.
    """
class DefaultNone:
    """
    Marker used as a default value in *expected_attributes*:
    the parameter may be left unset because the operator
    defines its own default behaviour for it
    (the attribute is then set to None).
    """
class OpRun:
    """
    Ancestor to all operators in this subfolder.
    The runtime for every node can be checked into
    `ONNX unit tests
    <https://github.com/onnx/onnx/tree/master/onnx/backend/test/case/node>`_.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        """
        @param      onnx_node               :epkg:`onnx` node
        @param      desc                    internal representation
        @param      expected_attributes     expected attributes for this node,
                                            dictionary *{name: default}*,
                                            a default equal to *None* means
                                            the attribute is mandatory,
                                            @see cl DefaultNone means it is
                                            set to None when missing
        @param      options                 runtime options

        Values stored in ``desc['atts']`` are merged into *options*
        (``value_rt`` is preferred over ``value``), then every option
        becomes an attribute of the instance.
        Raises *RuntimeError* if no schema is found for the operator or
        if a mandatory attribute is missing, *ValueError* if an attribute
        in *desc* has no value.
        """
        self._provider = 'python'
        self.onnx_node = onnx_node
        self.desc = desc
        self.inplaces = {}

        if onnx_node.op_type in _schemas:
            self._schema = _schemas[onnx_node.op_type]
        else:
            self._schema = self._find_custom_operator_schema(onnx_node.op_type)
            if self._schema is None:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find class name '{}' in available schemas:"
                    "(onnx.__version__='{}')\n{}".format(
                        self.__class__.__name__,
                        onnx.__version__,
                        "\n".join(sorted(_schemas))))

        if desc is not None:
            if 'atts' in desc:
                for a, b in desc['atts'].items():
                    if not isinstance(b, dict) or 'value' not in b:
                        raise ValueError(  # pragma: no cover
                            "Unexpected value {}.".format(b))
                    options[a] = (b['value_rt'] if 'value_rt' in b
                                  else b['value'])
        if expected_attributes is not None:
            if onnx_node.op_type in _at_least_one:
                # Only one of the expected attributes must be given.
                done = 0
                for a, b in expected_attributes.items():
                    if a in options:
                        setattr(self, a, b)
                        done += 1
                if done == 0:
                    # All expected names are reported, not only the last
                    # one of the loop (which may not even exist).
                    raise RuntimeError(  # pragma: no cover
                        "All parameters '{}' are missing from operator '{}', "
                        "given {}.".format(
                            sorted(expected_attributes),
                            onnx_node.op_type, list(sorted(options))))
            else:
                for a, b in expected_attributes.items():
                    if a not in options:
                        if b is DefaultNone:
                            setattr(self, a, None)
                        elif b is None:
                            raise RuntimeError(  # pragma: no cover
                                "Parameter '{}' is missing from operator '{}' "
                                "(class='{}'), given {}.".format(
                                    a, onnx_node.op_type,
                                    self.__class__.__name__,
                                    list(sorted(options))))
                        else:
                            setattr(self, a, b)
        # Given options override the defaults set above.
        for k, v in options.items():
            setattr(self, k, v)

        if onnx_node.op_type not in _at_least_one:
            for k, v in self._schema.attributes.items():
                if not hasattr(self, k) and getattr(v, 'required', True):
                    raise RuntimeError(  # pragma: no cover
                        "Attribute '{}' is expected based on ONNX specifications "
                        "for node '{}' and options {}.".format(
                            k, onnx_node.op_type, pprint.pformat(options)))

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return False

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of an operator unknown to :epkg:`onnx`,
        must be overwritten by custom operators.
        """
        raise NotImplementedError(  # pragma: no cover
            "This method should be overwritten for operator "
            "'{}'.".format(op_name))

    def __str__(self):
        """
        usual
        """
        atts = [self.__class__.__name__ + '(',
                "    op_type={}".format(self.onnx_node.op_type)]
        for k, v in sorted(self.__dict__.items()):
            if k in {'desc', 'onnx_node'}:
                continue
            # only public attributes (lower case, no trailing underscore)
            if 'a' <= k[0] <= 'z' and k[-1] != '_':
                atts.append('    {0}={1},'.format(k, v))
        atts.append(')')
        return "\n".join(atts)

    def _run(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(  # pragma: no cover
            "Method '_run' or 'to_python' should be overwritten for operator %s."
            "" % self.__class__.__name__)

    def run(self, *args, **kwargs):  # pylint: disable=E0202
        """
        Calls method ``_run``.
        """
        try:
            res = self._run(*args, **kwargs)
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                "Issues with types {} (operator {}).".format(
                    ", ".join(str(type(_)) for _ in args),
                    self.__class__.__name__)) from e
        return res

    def switch_initializers_dtype(self, dtype_in=numpy.float32,
                                  dtype_out=numpy.float64):
        """
        Casts every numpy array attribute of type *dtype_in*
        into *dtype_out*. If the operator defines method
        ``_run_no_checks_``, it also replaces method *run* by it.

        @param      dtype_in        previous type
        @param      dtype_out       next type
        @return                     done operations, list of tuples
                                    ``('+' or '-', 'att', name, value)``
        """
        done = []
        for k, v in sorted(self.__dict__.items()):
            if k in {'desc', 'onnx_node'}:
                continue
            if isinstance(v, numpy.ndarray):
                if v.dtype == dtype_in:
                    v = v.astype(dtype_out)
                    setattr(self, k, v)
                    done.append(("+", "att", k, getattr(self, k)))
                else:
                    done.append(("-", "att", k, getattr(self, k)))
        if hasattr(self, '_run_no_checks_') and hasattr(self, 'run'):
            self.run = self._run_no_checks_  # pylint: disable=E0202,E1101
        return done

    def infer_shapes(self, *args, **kwargs):
        """
        Infer shapes of the outputs given the shapes
        of the inputs. It works the same way as method *run*.
        """
        try:
            res = self._infer_shapes(*args, **kwargs)
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                "Issues with (operator '{}') and shapes\n{}"
                "\n----args\n{}\n------kwargs\n{}".format(
                    self.__class__.__name__,
                    "\n".join(str(_) for _ in args),
                    pprint.pformat(args),
                    pprint.pformat(kwargs))) from e
        if res is None:
            return res
        if not isinstance(res, tuple):
            raise TypeError(  # pragma: no cover
                "res must be tuple not {} (operator '{}')".format(
                    type(res), self.__class__.__name__))
        for a in res:
            if not isinstance(a, ShapeObject):
                raise TypeError(  # pragma: no cover
                    "One shape is not a ShapeObject but {} (operator '{}')".format(
                        type(a), self.__class__.__name__))
        return res

    def _infer_shapes(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(
            "This method should be overwritten for operator '{}'.".format(
                self.__class__.__name__))  # pragma: no cover

    def infer_types(self, *args, **kwargs):
        """
        Infer types of the outputs given the types
        of the inputs. It works the same way as method *run*.
        """
        try:
            res = self._infer_types(*args, **kwargs)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with (operator '{}') and types\n{}"
                "\n----args\n{}\n------kwargs\n{}".format(
                    self.__class__.__name__,
                    "\n".join(str(_) for _ in args),
                    pprint.pformat(args),
                    pprint.pformat(kwargs))) from e
        if not isinstance(res, tuple):
            raise TypeError(  # pragma: no cover
                "res must be tuple not {} (operator '{}')".format(
                    type(res), self.__class__.__name__))
        for a in res:
            if not isinstance(a, (numpy.dtype, SequenceType)) and a not in {
                    numpy.int8, numpy.uint8, numpy.float16, numpy.float32,
                    numpy.float64, numpy.int32, numpy.int64, numpy.int16,
                    numpy.uint16, numpy.uint32, numpy.bool_, numpy.str_,
                    numpy.uint64, bool, str}:
                raise TypeError(  # pragma: no cover
                    "Type ({}, {}) is not a numpy type or a sequence type "
                    "(operator '{}')".format(
                        a, type(a), self.__class__.__name__))
        return res

    def _infer_types(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(
            "This method should be overwritten for operator '{}'.".format(
                self.__class__.__name__))  # pragma: no cover

    def infer_sizes(self, *args, **kwargs):
        """
        Infer sizes required for computation.
        It works the same way as method *run*.
        The result must be a tuple.
        """
        try:
            res = self._infer_sizes(*args, **kwargs)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with (operator '{}') and types\n{}"
                "\n----args\n{}\n------kwargs\n{}".format(
                    self.__class__.__name__,
                    "\n".join(str(_) for _ in args),
                    pprint.pformat(args),
                    pprint.pformat(kwargs))) from e
        if not isinstance(res, tuple):
            # the message now matches the check (a tuple is expected)
            raise TypeError(  # pragma: no cover
                "res must be tuple not {} (operator '{}')".format(
                    type(res), self.__class__.__name__))
        return res

    def _infer_sizes(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        raise NotImplementedError(
            "This method should be overwritten for operator '{}'.".format(
                self.__class__.__name__))  # pragma: no cover

    def enable_inplace_compute(self, index):
        """
        Tells the node that one input can be overwritten.

        @param      index       input index
        """
        self.inplaces[index] = True

    @property
    def args_default(self):
        """
        Returns the list of arguments as well as
        the list of parameters with the default values
        (close to the signature).
        Empty lists, tuples and dictionaries are shown as None.
        """
        inps = []
        if hasattr(self, 'atts'):
            for k, v in self.atts.items():  # pylint: disable=E1101
                if isinstance(v, (list, tuple, dict)) and len(v) == 0:
                    v = None
                inps.append('%s=%r' % (k, v))
        return inps

    @property
    def args_default_modified(self):
        """
        Returns the list of modified parameters
        (parameters whose value differs from the default
        stored in *atts*), None if the class has no *atts*.
        """
        if not hasattr(self, 'atts'):
            return None

        inps = []
        for k, v in self.atts.items():  # pylint: disable=E1101
            val = getattr(self, k, None)
            if isinstance(val, numpy.ndarray) and isinstance(v, list):
                val = list(val)
            try:
                if val != v:
                    inps.append('%s=%r' % (k, val))
            except ValueError as e:  # pragma: no cover
                raise ValueError(
                    "Unexpected value for v=%r and val=%r." % (v, val)) from e
        return inps

    @property
    def args_optional(self):
        """
        Returns the list of optional arguments.
        """
        inps = []
        if hasattr(self, 'optional_inputs'):
            for k, v in self.optional_inputs.items():  # pylint: disable=E1101
                inps.append('%s=%r' % (k, v))
        return inps

    @property
    def args_mandatory(self):
        """
        Returns the list of mandatory arguments
        (attribute *mandatory_inputs*) or None if the
        operator does not define any.
        """
        if hasattr(self, 'mandatory_inputs'):
            return self.mandatory_inputs  # pylint: disable=E1101
        return None

    def to_python(self, inputs):
        """
        Returns a python code equivalent to this operator.

        @param      inputs      inputs name
        @return                 imports, python code, both as strings
        """
        raise NotImplementedError(
            "Operator '{}' has no equivalent python code.".format(self.__class__.__name__))  # pragma: no cover

    def _to_python_numpy(self, inputs, numpy_name):
        """
        Returns the imports and the python code calling
        numpy function *numpy_name* on *inputs*.
        """
        return ("import numpy",
                "return numpy.%s(%s)" % (numpy_name, ", ".join(inputs)))

    @property
    def atts_value(self):
        "Returns all parameters in a dictionary."
        if hasattr(self, 'atts'):
            return {k: getattr(self, k)
                    for k in self.atts}  # pylint: disable=E1101
        return None
class OpRunUnary(OpRun):
    """
    Ancestor to all unary operators in this subfolder.
    Checks that inputs type are the same.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def run(self, x):  # pylint: disable=E0202,W0221
        """
        Calls method ``_run``.

        @param      x       unique input
        @return             whatever ``_run`` returns

        A *TypeError* raised by ``_run`` is chained into a new
        *TypeError* mentioning the input type.
        """
        try:
            res = self._run(x)
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                "Issues with types {} (unary operator {}).".format(
                    ", ".join(str(type(_)) for _ in [x]),
                    self.__class__.__name__)) from e
        return res

    def infer_shapes(self, x):  # pylint: disable=E0202,W0221
        """
        Calls ``_infer_shapes``, chains any *TypeError*.
        """
        try:
            return self._infer_shapes(x)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (operator {}).".format(
                    x.dtype, self.__class__.__name__)) from e

    def _infer_shapes(self, x):  # pylint: disable=E0202,W0221
        """
        Returns the same shape by default.
        """
        return (x, )

    def infer_types(self, x):  # pylint: disable=E0202,W0221
        """
        Calls ``_infer_types``, chains any *TypeError*.
        """
        try:
            return self._infer_types(x)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (operator {}).".format(
                    x, self.__class__.__name__)) from e

    def _infer_types(self, x):  # pylint: disable=E0202,W0221
        """
        Returns the same type by default.
        """
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a dictionary
        ``dict(temp=0)`` (no temporary buffer) to its outputs.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res
class OpRunArg(OpRunUnary):
    """
    Ancestor to all unary operators in this subfolder
    and which produces position of extremas (ArgMax, ...).
    Checks that inputs type are the same.
    The class must have attributes *axis*, *keepdims*.
    The output must be of type ``numpy.int64``.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=expected_attributes,
                            **options)
        if not hasattr(self, 'keepdims'):
            raise AttributeError(  # pragma: no cover
                "Attribute 'keepdims' is missing.")
        if not hasattr(self, 'axis'):
            raise AttributeError(  # pragma: no cover
                "Attribute 'axis' is missing.")

    def run(self, x):  # pylint: disable=E0202
        """
        Calls method ``_run`` and checks the first output
        is of type ``numpy.int64``.
        """
        res = OpRunUnary.run(self, x)
        if res[0].dtype != numpy.int64:
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: should be '{}' != output '{}' "
                "(operator '{}')".format(
                    numpy.int64, res[0].dtype, self.__class__.__name__))
        return res

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Reduces the input shape along *axis* (keeping it if *keepdims*),
        the output type is ``numpy.int64``.
        """
        sh = x.reduce(self.axis, self.keepdims,  # pylint: disable=E1101
                      dtype=numpy.int64)  # pylint: disable=E1101
        return (sh, )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Positions are always ``numpy.int64``.
        """
        return (numpy.int64, )

    def _run_no_checks_(self, x):  # pylint: disable=W0221
        """
        Same as *run* without the output type check.
        """
        return OpRunUnary.run(self, x)
class OpRunUnaryNum(OpRunUnary):
    """
    Base class for unary numerical operators in this subfolder.
    The first output must share the dtype of the input.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnary.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)

    def run(self, x):  # pylint: disable=E0202
        """
        Runs the operator through ``OpRunUnary.run`` and verifies
        the first output has the same dtype as *x*.
        Empty results, missing outputs and lists are not checked.
        """
        outputs = OpRunUnary.run(self, x)
        if len(outputs) == 0:
            return outputs
        first = outputs[0]
        if first is None or isinstance(first, list):
            return outputs
        if first.dtype != x.dtype:
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: input '{}' != output '{}' "
                "(operator '{}')".format(
                    x.dtype, first.dtype, self.__class__.__name__))
        return outputs

    def _run_no_checks_(self, x):  # pylint: disable=W0221
        """
        Same as *run* without the dtype check.
        """
        return OpRunUnary.run(self, x)
class OpRunClassifierProb(OpRunUnary):
    """
    Ancestor to all operators in this subfolder producing
    labels and probabilities (or scores).
    Checks that the probabilities have the same type as the
    input when the input is a float tensor.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=expected_attributes,
                            **options)

    def run(self, x):  # pylint: disable=E0202
        """
        Calls method ``_run`` and checks the second output
        has the same dtype as a float input.
        """
        res = OpRunUnary.run(self, x)
        if x.dtype in (numpy.float32, numpy.float64) and res[1].dtype != x.dtype:
            raise RuntimeTypeError(  # pragma: no cover
                "Output type mismatch: {} != {} (operator '{}')".format(
                    x.dtype, res[1].dtype, self.__class__.__name__))
        return res

    @property
    def nb_classes(self):
        """
        Returns the number of expected classes, the length of the longest
        attribute among *classlabels_ints*, *classlabels_int64s* and
        *classlabels_strings* (a missing attribute counts as empty).
        """
        return max(len(getattr(self, 'classlabels_ints', [])),
                   len(getattr(self, 'classlabels_int64s', [])),
                   len(getattr(self, 'classlabels_strings', [])))

    def _run_no_checks_(self, x):  # pylint: disable=W0221
        """
        Same as *run* without the dtype check.
        """
        return OpRunUnary.run(self, x)

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns the same for the labels and the probabilities.
        """
        return (ShapeObject((x[0], ), dtype=numpy.int64,
                            name="{}-0".format(self.__class__.__name__)),
                ShapeObject((x[0], self.nb_classes), dtype=x.dtype,
                            name="{}-1".format(self.__class__.__name__)))

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the type of the labels and the probabilities.
        """
        return (numpy.int64, x.dtype)
class OpRunBinary(OpRun):
    """
    Ancestor to all binary operators in this subfolder.
    Checks that inputs type are the same.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def run(self, x, y):  # pylint: disable=E0202,W0221
        """
        Calls method ``_run``.

        Raises *RuntimeError* if one input is None,
        @see cl RuntimeTypeError if both inputs have different dtypes.
        """
        if x is None or y is None:
            raise RuntimeError(  # pragma: no cover
                "x and y cannot be None: type(x)={}, type(y)={} ({})".format(
                    type(x), type(y), type(self)))
        if x.dtype != y.dtype:
            raise RuntimeTypeError(
                "Input type mismatch: {} != {} (operator '{}', shapes {}, {})".format(
                    x.dtype, y.dtype, self.__class__.__name__,
                    x.shape, y.shape))
        try:
            res = self._run(x, y)
        except (TypeError, ValueError) as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (binary operator {}).".format(
                    ", ".join(str(type(_)) for _ in [x, y]),
                    self.__class__.__name__)) from e
        return res

    def _run_no_checks_(self, x, y):  # pylint: disable=W0221
        """
        Calls method ``_run`` without checking the inputs.
        """
        try:
            res = self._run(x, y)
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Issues with types {} (binary operator {}).".format(
                    ", ".join(str(type(_)) for _ in [x, y]),
                    self.__class__.__name__)) from e
        return res

    def _infer_shapes(self, x, y):  # pylint: disable=W0221
        """
        Returns the same shape by default.
        We assume the operator returns the biggest
        shapes as the operator could be using broadcasting.
        """
        if x is None or y is None:
            return None
        try:
            res = x.broadcast(y)
            add = "broadcast"
        except RuntimeError:  # pragma: no cover
            # We know x and y have the same number of dimensions.
            # We pick the first one even if it might be wrong.
            res = x
            add = "1"
        if res.name is None:
            return (res.copy(name="{}{}".format(
                self.__class__.__name__, add)), )
        return (res.copy(name="{}-{}{}".format(
            res.name, self.__class__.__name__, add)), )

    def _infer_types(self, x, y):  # pylint: disable=W0221
        """
        Returns the type of the first input.
        """
        return (x, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a dictionary
        ``dict(temp=0)`` (no temporary buffer) to its outputs.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res
class OpRunBinaryComparison(OpRunBinary):
    """
    Base class for binary operators comparing two tensors,
    the inferred output type is always boolean.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunBinary.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)

    def _infer_types(self, x, y):  # pylint: disable=W0221
        """
        Comparisons produce booleans whatever the input types.
        """
        return (numpy.bool_, )
class OpRunBinaryNum(OpRunBinary):
    """
    Base class for binary numerical operators in this subfolder.
    The first output must share the dtype of the first input.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunBinary.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)

    def run(self, x, y):  # pylint: disable=E0202
        """
        Runs the operator through ``OpRunBinary.run`` and verifies
        the first output has the same dtype as *x*.
        """
        outputs = OpRunBinary.run(self, x, y)
        out_dtype = outputs[0].dtype
        if out_dtype == x.dtype:
            return outputs
        raise RuntimeTypeError(
            "Output type mismatch: {} != {} or {} (operator '{}')"
            " type(x)={} type(y)={}".format(
                x.dtype, out_dtype, y.dtype,
                self.__class__.__name__, type(x), type(y)))

    def _run_no_checks_(self, x, y):  # pylint: disable=W0221
        """
        Same as *run* without the dtype checks.
        """
        return OpRunBinary._run_no_checks_(self, x, y)
class OpRunBinaryNumpy(OpRunBinaryNum):
    """
    Implements the inplaces logic.
    *numpy_fct* is a binary numpy function which
    takes two matrices and has an argument *out*
    for inplace operations.
    """

    def __init__(self, numpy_fct, onnx_node, desc=None,
                 expected_attributes=None, **options):
        """
        @param      numpy_fct               binary numpy function accepting
                                            parameter *out*
        @param      onnx_node               :epkg:`onnx` node
        @param      desc                    internal representation
        @param      expected_attributes     expected attributes for this node
        @param      options                 runtime options
        """
        OpRunBinaryNum.__init__(self, onnx_node, desc=desc,
                                expected_attributes=expected_attributes,
                                **options)
        self.numpy_fct = numpy_fct
        # A division of integers produces floats which cannot be
        # stored back into one of the integer inputs.
        self._cannot_inplace_int = self.numpy_fct in (
            numpy.divide, numpy.true_divide)

    def _run(self, a, b):  # pylint: disable=W0221
        """
        Applies *numpy_fct* on *a* and *b*, writing the result into
        an input if it was declared as overwritable
        (@see me enable_inplace_compute) and is big enough,
        falls back to a new array if numpy refuses the inplace operation.
        """
        if (self._cannot_inplace_int and
                numpy.issubdtype(a.dtype, numpy.integer)):
            return (self.numpy_fct(a, b), )
        if self.inplaces.get(0, False) and a.size >= b.size:
            if len(a.shape) == 1 and b.shape == (1, 1):
                # broadcasting (n,) with (1, 1) gives (1, n)
                a = a.reshape(1, a.shape[0])
            try:
                self.numpy_fct(a, b, out=a)
                return (a, )
            except (ValueError, TypeError):
                return (self.numpy_fct(a, b), )
        if self.inplaces.get(1, False) and a.size <= b.size:
            if len(b.shape) == 1 and a.shape == (1, 1):
                # broadcasting (1, 1) with (n,) gives (1, n), not (n, 1)
                b = b.reshape(1, b.shape[0])
            try:
                self.numpy_fct(a, b, out=b)
                return (b, )
            except (ValueError, TypeError):
                return (self.numpy_fct(a, b), )
        return (self.numpy_fct(a, b), )

    def to_python(self, inputs):
        """
        Returns a python code equivalent to this operator.

        @param      inputs      inputs name
        @return                 imports, python code, both as strings
        """
        lines = [
            "# inplaces not take into account {}-{}".format(
                self.inplaces.get(0, False), self.inplaces.get(1, False)),
            "return numpy.{0}({1})".format(
                self.numpy_fct.__name__, ', '.join(inputs))
        ]
        return "import numpy", "\n".join(lines)
class OpRunReduceNumpy(OpRunUnaryNum):
    """
    Implements the reduce logic.
    It must have a parameter *axes* or a parameter
    *noop_with_empty_axes*.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        """
        @param      onnx_node               :epkg:`onnx` node
        @param      desc                    internal representation
        @param      expected_attributes     expected attributes, must contain
                                            *axes* or *noop_with_empty_axes*
        @param      options                 runtime options

        Attribute *axes*, when it exists, is normalized into
        a tuple or None (empty axes).
        """
        atts = expected_attributes if expected_attributes is not None else {}
        if ('noop_with_empty_axes' not in atts and
                'axes' not in atts):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'axes' is expected but not found in {} "
                "from class {}".format(expected_attributes, type(self)))
        if (atts.get('noop_with_empty_axes', 0) and
                (atts.get('axes', None) is None or
                 len(atts['axes']) == 0)):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'axes' cannot be empty as {} (noop_with_empty_axes=1) "
                "from class {}".format(expected_attributes, type(self)))
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=expected_attributes,
                               **options)
        if not hasattr(self, 'axes'):
            # Accepted above when only noop_with_empty_axes is expected,
            # presumably axes are then given at runtime -- nothing to
            # normalize.
            return
        if isinstance(self.axes, numpy.ndarray):  # pylint: disable=E0203
            if (len(self.axes.shape) == 0 or  # pylint: disable=E0203
                    self.axes.shape[0] == 0):  # pylint: disable=E0203
                self.axes = None
            else:
                self.axes = tuple(self.axes)
        elif self.axes in [[], tuple()]:  # pylint: disable=E0203
            self.axes = None
        elif isinstance(self.axes, list):  # pylint: disable=E0203
            self.axes = tuple(self.axes)
class OpRunCustom(OpRun):
    """
    Automates some methods for custom operators defined
    outside *mlprodict*.
    """

    class OpRunCustomSchema(OperatorSchema):
        """
        Schema built from a custom operator class:
        its name is the class name, its attributes
        come from the class attribute *atts*.
        """

        def __init__(self, cls):
            OperatorSchema.__init__(self, cls.__name__)
            self.attributes = cls.atts

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns a schema for this class when *op_name* is either
        the class name or the value of the class attribute *op_name*,
        raises *RuntimeError* otherwise.
        """
        klass = self.__class__
        if op_name == klass.__name__:
            return OpRunCustom.OpRunCustomSchema(klass)
        if hasattr(klass, 'op_name') and klass.op_name == op_name:  # pylint: disable=E1101
            return OpRunCustom.OpRunCustomSchema(klass)
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))