Coverage for mlprodict/onnxrt/ops_cpu/op_normalizer.py: 93%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

42 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRunUnaryNum 

9 

10 

class Normalizer(OpRunUnaryNum):
    """
    Runtime for the ONNX operator *Normalizer*.

    Every row of the input is divided by one of its norms, selected by
    attribute *norm*:

    * ``MAX``: largest absolute value of the row,
    * ``L1``: sum of the absolute values of the row,
    * ``L2``: square root of the sum of the squared values of the row.

    A 2-D input ``[N, C]`` is normalized row by row; a 1-D input ``[C]``
    is treated as a single row. A row whose norm is zero is returned
    unchanged instead of being filled with NaN (``0 / 0``).
    """

    atts = {'norm': 'MAX'}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Normalizer.atts,
                               **options)
        norm = self.norm  # pylint: disable=E1101
        # Values coming from the node are compared as bytes (b'MAX') while
        # the default declared in ``atts`` is a str ('MAX'): accept both so
        # the default is not rejected if the base class assigns it as-is.
        if isinstance(norm, bytes):
            norm = norm.decode('ascii', errors='replace')
        if norm == 'MAX':
            self._norm = Normalizer.norm_max
        elif norm == 'L1':
            self._norm = Normalizer.norm_l1
        elif norm == 'L2':
            self._norm = Normalizer.norm_l2
        else:
            raise ValueError(  # pragma: no cover
                "Unexpected value for norm='{}'.".format(self.norm))  # pylint: disable=E1101

    @staticmethod
    def _reduce_rows(values, reducer, **kwargs):
        """
        Reduces every row of *values* with *reducer* (a numpy reduction
        such as ``numpy.max`` or ``numpy.sum``) and keeps the reduced axis
        so that the result broadcasts against *values*: ``[N, C]`` gives
        ``[N, 1]`` (also when ``N == 0``), ``[C]`` gives ``[1]``.
        Extra keyword arguments are forwarded to *reducer*.
        """
        axis = 0 if values.ndim == 1 else 1
        return reducer(values, axis=axis, keepdims=True, **kwargs)

    @staticmethod
    def _divide(x, norm, inplace):
        """
        Divides *x* by *norm* (as returned by ``_reduce_rows``).

        Null norms are replaced by 1 first so that all-zero rows stay
        unchanged rather than becoming NaN. *norm* is modified, so it must
        be an array owned by the caller. When *inplace* is True the result
        is written into *x*, which is returned.
        """
        norm[norm == 0] = 1
        if inplace:
            numpy.divide(x, norm, out=x)
            return x
        return x / norm

    @staticmethod
    def norm_max(x, inplace):
        "max normalization"
        # initial=0 lets a row without any feature (C == 0) reduce to 0
        # instead of raising; it cannot change a maximum of absolute values.
        norm = Normalizer._reduce_rows(numpy.abs(x), numpy.max, initial=0)
        return Normalizer._divide(x, norm, inplace)

    @staticmethod
    def _norm_max_inplace(x):
        # Kept for backward compatibility: same as norm_max(x, True).
        return Normalizer.norm_max(x, True)

    @staticmethod
    def norm_l1(x, inplace):
        "L1 normalization"
        norm = Normalizer._reduce_rows(numpy.abs(x), numpy.sum)
        return Normalizer._divide(x, norm, inplace)

    @staticmethod
    def _norm_L1_inplace(x):  # pylint: disable=C0103
        # Kept for backward compatibility: same as norm_l1(x, True).
        return Normalizer.norm_l1(x, True)

    @staticmethod
    def norm_l2(x, inplace):
        "L2 normalization"
        # numpy.sqrt allocates its own result (float for integer inputs),
        # unlike sqrt(..., out=...) which fails to cast back to integers.
        norm = numpy.sqrt(
            Normalizer._reduce_rows(numpy.square(x), numpy.sum))
        return Normalizer._divide(x, norm, inplace)

    def _run(self, x):  # pylint: disable=W0221
        # self.inplaces is defined outside this class; presumably it tells
        # whether input 0 may be overwritten (verify against OpRun).
        return (self._norm(x, inplace=self.inplaces.get(0, False)), )