Coverage for mlprodict/onnxrt/ops_cpu/op_hardmax.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

14 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRunUnaryNum 

9 

10 

class Hardmax(OpRunUnaryNum):
    """
    Runtime for the *Hardmax* operator.

    The output has the same shape and dtype as the input. Along the
    attribute *axis* (default ``-1``), the position of the first maximum
    is set to 1 and every other position is set to 0.
    """

    # Default attribute values; ``-1`` selects the last axis.
    # NOTE(review): ``axis=-1`` as default looks like the ONNX opset-13
    # definition (older opsets default to 1 and flatten to 2D) — confirm
    # which opset this runtime targets.
    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        @param      onnx_node       ONNX node this operator executes
        @param      desc            optional node description
        @param      options         extra options forwarded to the base class
        """
        # The base class presumably exposes each entry of ``atts`` as an
        # instance attribute (``self.axis`` is read in ``_run``) — confirm.
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Hardmax.atts,
                               **options)

    def _run(self, x):  # pylint: disable=W0221
        """
        Computes the one-hot encoding of the argmax of *x* along *axis*.

        @param      x       input array
        @return             tuple with one array, same shape and dtype as *x*
        """
        if x.size == 0:
            # numpy.argmax raises ValueError when the reduced axis is
            # empty; an empty input simply yields an empty output.
            return (numpy.zeros_like(x), )
        # numpy.argmax returns the index of the first occurrence of the
        # maximum, so ties resolve to the lowest index.
        x_argmax = numpy.argmax(x, axis=self.axis)
        y = numpy.zeros_like(x)
        # put_along_axis needs indices with the same rank as y, hence
        # expand_dims to restore the axis removed by argmax.
        numpy.put_along_axis(y, numpy.expand_dims(x_argmax, axis=self.axis),
                             1, axis=self.axis)
        return (y, )

    def to_python(self, inputs):
        """
        Returns python code equivalent to this operator.

        @param      inputs      list of input names, only the first one is used
        @return                 tuple ``(imports, code)``, both strings
        """
        # NOTE(review): the generated code reads a variable named ``axis``;
        # presumably the exporter binds it to the node attribute — confirm.
        name = inputs[0]
        return ("import numpy",
                "\n".join([
                    # Same empty-input guard as in _run, kept on one line.
                    "if {0}.size == 0: return numpy.zeros_like({0})".format(
                        name),
                    "{0}_argmax = numpy.argmax({0}, axis=axis)".format(
                        name),
                    "{0}y = numpy.zeros_like({0})".format(name),
                    "numpy.put_along_axis({0}y,".format(name),
                    "    numpy.expand_dims(",
                    "        {0}_argmax, axis=axis),".format(name),
                    "    1, axis=axis)",
                    "return {0}y".format(name)]))