Coverage for mlprodict/onnxrt/ops_cpu/op_argmin.py: 95%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRunArg
12def _argmin(data, axis=0, keepdims=True):
13 result = numpy.argmin(data, axis=axis)
14 if keepdims and len(result.shape) < len(data.shape):
15 result = numpy.expand_dims(result, axis)
16 return result.astype(numpy.int64)
19def _argmin_use_numpy_select_last_index(
20 data, axis=0, keepdims=True):
21 data = numpy.flip(data, axis)
22 result = numpy.argmin(data, axis=axis)
23 result = data.shape[axis] - result - 1
24 if keepdims:
25 result = numpy.expand_dims(result, axis)
26 return result.astype(numpy.int64)
class _ArgMin(OpRunArg):
    """
    Base class for runtime for operator `ArgMin
    <https://github.com/onnx/onnx/blob/master/docs/
    Operators.md#ArgMin>`_.

    Subclasses supply the expected attributes. `_run` reads `self.axis`
    and `self.keepdims`, which are presumably populated by `OpRunArg`
    from the node attributes (not visible here).
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        OpRunArg.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes,
            **options)

    def _run(self, data):  # pylint: disable=W0221
        # Returns a one-element tuple holding the first-minimum indices.
        indices = _argmin(data, axis=self.axis, keepdims=self.keepdims)
        return (indices, )
class ArgMin_11(_ArgMin):
    """
    ArgMin runtime with attributes *axis* (default 0) and *keepdims*
    (default 1).
    """

    atts = {'axis': 0, 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        _ArgMin.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ArgMin_11.atts, **options)

    def to_python(self, inputs):
        """
        Returns a pair *(imports, code)* where *code* calls `_argmin` on
        the first input name with the `axis` and `keepdims` variables.
        """
        header = ('import numpy\n'
                  'from mlprodict.onnxrt.ops_cpu.op_argmin import _argmin')
        code = 'return _argmin(%s, axis=axis, keepdims=keepdims)' % inputs[0]
        return header, code
class ArgMin_12(_ArgMin):
    """
    ArgMin runtime with attributes *axis* (default 0), *keepdims*
    (default 1) and *select_last_index* (default 0).
    """

    atts = {'axis': 0, 'keepdims': 1, 'select_last_index': 0}

    def __init__(self, onnx_node, desc=None, **options):
        _ArgMin.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ArgMin_12.atts, **options)

    def _run(self, data):  # pylint: disable=W0221
        # Any non-zero select_last_index selects the last minimum;
        # zero falls back to the base implementation (first minimum).
        if self.select_last_index != 0:
            return (_argmin_use_numpy_select_last_index(
                data, axis=self.axis, keepdims=self.keepdims), )
        return _ArgMin._run(self, data)

    def to_python(self, inputs):
        """
        Returns a pair *(imports, code)* where *code* branches on the
        `select_last_index` variable to call the matching helper on the
        first input name.
        """
        name = inputs[0]
        header = ('import numpy\nfrom mlprodict.onnxrt.ops_cpu.op_argmin '
                  'import _argmin, _argmin_use_numpy_select_last_index')
        template = "\n".join([
            "if select_last_index == 0:",
            " return _argmin({0}, axis=axis, keepdims=keepdims)",
            "return _argmin_use_numpy_select_last_index(",
            " {0}, axis=axis, keepdims=keepdims)"])
        return header, template.format(name)
# Public alias: the newest ArgMin implementation supported by the
# installed onnx opset.
if onnx_opset_version() < 12:
    ArgMin = ArgMin_11  # pragma: no cover
else:
    ArgMin = ArgMin_12