Coverage for mlprodict/onnxrt/ops_cpu/op_argmax.py: 95%
Shortcuts on this page
r, m, x: toggle line displays
j, k: jump to next/previous highlighted chunk
0 (zero): go to top of page
1 (one): go to first highlighted chunk
Shortcuts on this page
r, m, x: toggle line displays
j, k: jump to next/previous highlighted chunk
0 (zero): go to top of page
1 (one): go to first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRunArg
def _argmax(data, axis=0, keepdims=True):
    """
    Computes the indices of the maximum values of *data* along *axis*.
    On ties, :func:`numpy.argmax` returns the first occurrence.

    :param data: input array
    :param axis: axis to reduce (negative values count from the end)
    :param keepdims: if true, the reduced axis is kept with size 1
    :return: indices, converted to int64
    """
    indices = numpy.argmax(data, axis=axis)
    # keepdims is tested first so that data.shape is only read when needed
    if keepdims and indices.ndim < len(data.shape):
        # put back the reduced axis as a dimension of size 1
        indices = numpy.expand_dims(indices, axis)
    return indices.astype(numpy.int64)
def _argmax_use_numpy_select_last_index(
        data, axis=0, keepdims=True):
    """
    Computes the indices of the maximum values of *data* along *axis*,
    returning the last occurrence on ties. The input is reversed along
    *axis*, so the first maximum found is the last one in the original
    order. That position is then mapped back to an original index.

    :param data: input array
    :param axis: axis to reduce (negative values count from the end)
    :param keepdims: if true, the reduced axis is kept with size 1
    :return: indices, converted to int64
    """
    flipped = numpy.flip(data, axis)
    pos_in_flipped = numpy.argmax(flipped, axis=axis)
    # position p in the reversed array is index (n - 1 - p) in the original
    indices = (flipped.shape[axis] - 1) - pos_in_flipped
    if keepdims and len(indices.shape) < len(flipped.shape):
        # put back the reduced axis as a dimension of size 1
        indices = numpy.expand_dims(indices, axis)
    return indices.astype(numpy.int64)
class _ArgMax(OpRunArg):
    """
    Common base for the runtime of operator `ArgMax
    <https://github.com/onnx/onnx/blob/master/docs/
    Operators.md#ArgMax>`_. Subclasses pass the attributes
    they expect through *expected_attributes*.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        OpRunArg.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)

    def _run(self, data):  # pylint: disable=W0221
        # self.axis / self.keepdims are presumably populated by OpRunArg
        # from expected_attributes -- confirm in OpRunArg
        indices = _argmax(data, axis=self.axis, keepdims=self.keepdims)
        return (indices, )

    def to_python(self, inputs):
        """
        Returns a pair *(imports, code)* describing this operator as
        Python source; *inputs[0]* is inserted as the data argument.
        """
        header = ('import numpy\n'
                  'from mlprodict.onnxrt.ops_cpu.op_argmax import _argmax')
        body = 'return _argmax(%s, axis=axis, keepdims=keepdims)' % inputs[0]
        return header, body
class ArgMax_11(_ArgMax):
    """
    Runtime for operator `ArgMax` with attributes *axis* (default 0)
    and *keepdims* (default 1).
    """

    atts = {'axis': 0, 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        _ArgMax.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ArgMax_11.atts, **options)
class ArgMax_12(_ArgMax):
    """
    Runtime for operator `ArgMax` with attributes *axis* (default 0),
    *keepdims* (default 1) and *select_last_index* (default 0). When
    *select_last_index* is non-zero, ties resolve to the last
    occurrence instead of the first.
    """

    atts = {'axis': 0, 'keepdims': 1, 'select_last_index': 0}

    def __init__(self, onnx_node, desc=None, **options):
        _ArgMax.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ArgMax_12.atts, **options)

    def _run(self, data):  # pylint: disable=W0221
        if self.select_last_index != 0:
            last = _argmax_use_numpy_select_last_index(
                data, axis=self.axis, keepdims=self.keepdims)
            return (last, )
        # default behaviour: first occurrence, shared with the base class
        return _ArgMax._run(self, data)

    def to_python(self, inputs):
        """
        Returns a pair *(imports, code)*; the generated code branches on
        *select_last_index* at runtime, and *inputs[0]* is inserted as
        the data argument.
        """
        template = "\n".join([
            "if select_last_index == 0:",
            " return _argmax({0}, axis=axis, keepdims=keepdims)",
            "return _argmax_use_numpy_select_last_index(",
            " {0}, axis=axis, keepdims=keepdims)",
        ])
        header = ('import numpy\nfrom mlprodict.onnxrt.ops_cpu.op_argmax '
                  'import _argmax, _argmax_use_numpy_select_last_index')
        return header, template.format(inputs[0])
# Expose the opset 12 implementation when the installed onnx package
# reports opset >= 12, otherwise fall back to the opset 11 one.
ArgMax = ArgMax_12 if onnx_opset_version() >= 12 else ArgMax_11