Coverage for mlprodict/onnxrt/ops_cpu/op_argmin.py: 95%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

40 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRunArg 

10 

11 

def _argmin(data, axis=0, keepdims=True):
    """
    Computes the indices of the minimum values of *data* along *axis*.
    When several values are equal to the minimum, numpy returns
    the first occurrence.

    :param data: input array
    :param axis: axis to reduce
    :param keepdims: if true, the reduced axis is kept with size 1
    :return: indices as an array (or numpy scalar) of type int64
    """
    indices = numpy.argmin(data, axis=axis)
    if not keepdims:
        return indices.astype(numpy.int64)
    if indices.ndim < len(data.shape):
        # numpy removes the reduced axis, it is restored with size 1
        indices = numpy.expand_dims(indices, axis)
    return indices.astype(numpy.int64)

17 

18 

def _argmin_use_numpy_select_last_index(
        data, axis=0, keepdims=True):
    """
    Computes the indices of the minimum values of *data* along *axis*,
    returning the last occurrence when the minimum appears several times.

    :param data: input array
    :param axis: axis to reduce
    :param keepdims: if true, the reduced axis is kept with size 1
    :return: indices as an array (or numpy scalar) of type int64
    """
    reversed_data = numpy.flip(data, axis)
    # the first minimum of the reversed array is the last one of the
    # original array, counted from the end of the axis
    position_from_end = numpy.argmin(reversed_data, axis=axis)
    last_index = reversed_data.shape[axis] - 1 - position_from_end
    if keepdims:
        last_index = numpy.expand_dims(last_index, axis)
    return last_index.astype(numpy.int64)

27 

28 

class _ArgMin(OpRunArg):
    """
    Base class for runtime for operator `ArgMin
    <https://github.com/onnx/onnx/blob/master/docs/
    Operators.md#ArgMin>`_.
    Subclasses declare the supported attributes through
    *expected_attributes*.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        OpRunArg.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)

    def _run(self, data):  # pylint: disable=W0221
        # self.axis and self.keepdims are node attributes
        # (see the subclasses' atts for their defaults)
        indices = _argmin(data, axis=self.axis, keepdims=self.keepdims)
        return (indices, )

44 

45 

class ArgMin_11(_ArgMin):
    """
    Runtime for operator ArgMin with attributes *axis* (default 0)
    and *keepdims* (default 1).
    """

    atts = {'axis': 0, 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        _ArgMin.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ArgMin_11.atts, **options)

    def to_python(self, inputs):
        """
        Returns a tuple (import lines, code) reproducing this operator
        in python, *inputs[0]* being the name of the input variable.
        """
        header = ('import numpy\nfrom mlprodict.onnxrt.ops_cpu.op_argmin '
                  'import _argmin')
        body = 'return _argmin(%s, axis=axis, keepdims=keepdims)' % inputs[0]
        return header, body

58 

59 

class ArgMin_12(_ArgMin):
    """
    Runtime for operator ArgMin with attributes *axis* (default 0),
    *keepdims* (default 1) and *select_last_index* (default 0).
    """

    atts = {'axis': 0, 'keepdims': 1, 'select_last_index': 0}

    def __init__(self, onnx_node, desc=None, **options):
        _ArgMin.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=ArgMin_12.atts, **options)

    def _run(self, data):  # pylint: disable=W0221
        # non-zero select_last_index: last occurrence of the minimum
        if self.select_last_index != 0:
            indices = _argmin_use_numpy_select_last_index(
                data, axis=self.axis, keepdims=self.keepdims)
            return (indices, )
        # otherwise the base implementation (first occurrence) applies
        return _ArgMin._run(self, data)

    def to_python(self, inputs):
        """
        Returns a tuple (import lines, code) reproducing this operator
        in python, *inputs[0]* being the name of the input variable.
        The generated code dispatches on *select_last_index*.
        """
        template = "\n".join([
            "if select_last_index == 0:",
            "    return _argmin({0}, axis=axis, keepdims=keepdims)",
            "return _argmin_use_numpy_select_last_index(",
            "    {0}, axis=axis, keepdims=keepdims)",
        ])
        header = ('import numpy\nfrom mlprodict.onnxrt.ops_cpu.op_argmin '
                  'import _argmin, _argmin_use_numpy_select_last_index')
        return header, template.format(inputs[0])

83 

84 

# ArgMin is the latest implementation supported by the installed onnx:
# select_last_index (ArgMin_12) requires opset 12.
if onnx_opset_version() < 12:
    ArgMin = ArgMin_11  # pragma: no cover
else:
    ArgMin = ArgMin_12