Coverage for mlprodict/onnxrt/ops_cpu/op_normalizer.py: 93%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRunUnaryNum
class Normalizer(OpRunUnaryNum):
    """
    Runtime for operator *Normalizer*: every row of the input
    is divided by one of its norms (``MAX``, ``L1`` or ``L2``),
    selected by attribute *norm*.

    Rows whose norm is zero are returned unchanged instead of being
    turned into NaN by a division by zero.
    """

    atts = {'norm': 'MAX'}

    def __init__(self, onnx_node, desc=None, **options):
        """
        @param      onnx_node   ONNX node, forwarded to the base class
        @param      desc        forwarded to the base class
        @param      options     forwarded to the base class
        @raise      ValueError  if attribute *norm* is not one of
                                ``MAX``, ``L1``, ``L2``
        """
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Normalizer.atts,
                               **options)
        # The default declared in ``atts`` is a str whereas the value
        # was previously only compared against bytes: accept both so a
        # str value does not fall into the error branch.
        kind = self.norm  # pylint: disable=E1101
        if isinstance(kind, bytes):
            kind = kind.decode('ascii', errors='replace')
        if kind == 'MAX':
            self._norm = Normalizer.norm_max
        elif kind == 'L1':
            self._norm = Normalizer.norm_l1
        elif kind == 'L2':
            self._norm = Normalizer.norm_l2
        else:
            raise ValueError(  # pragma: no cover
                "Unexpected value for norm='{}'.".format(self.norm))  # pylint: disable=E1101

    @staticmethod
    def _row_divisor(norms, n_rows):
        """
        Reshapes the norms so that they start with *n_rows* rows and
        replaces null norms by one, which leaves null rows unchanged
        by the division.

        @param      norms       freshly computed norms (modified inplace)
        @param      n_rows      first dimension of the normalized array
        @return                 divisor
        """
        divisor = norms.reshape((n_rows, -1))
        divisor[divisor == 0] = 1
        return divisor

    @staticmethod
    def norm_max(x, inplace):
        "max normalization, divides every row by its largest absolute value"
        if inplace:
            return Normalizer._norm_max_inplace(x)
        return x / Normalizer._row_divisor(
            numpy.abs(x).max(axis=1), x.shape[0])

    @staticmethod
    def _norm_max_inplace(x):
        "max normalization, the result overwrites *x*"
        numpy.divide(
            x, Normalizer._row_divisor(numpy.abs(x).max(axis=1), x.shape[0]),
            out=x)
        return x

    @staticmethod
    def norm_l1(x, inplace):
        "L1 normalization, divides every row by the sum of its absolute values"
        if inplace:
            return Normalizer._norm_L1_inplace(x)
        return x / Normalizer._row_divisor(
            numpy.abs(x).sum(axis=1), x.shape[0])

    @staticmethod
    def _norm_L1_inplace(x):
        "L1 normalization, the result overwrites *x*"
        numpy.divide(
            x, Normalizer._row_divisor(numpy.abs(x).sum(axis=1), x.shape[0]),
            out=x)
        return x

    @staticmethod
    def norm_l2(x, inplace):
        "L2 normalization, divides every row by its euclidean norm"
        xn = numpy.square(x).sum(axis=1)
        # the square root reuses the buffer holding the sum of squares
        numpy.sqrt(xn, out=xn)
        norm = Normalizer._row_divisor(xn, x.shape[0])
        if inplace:
            numpy.divide(x, norm, out=x)
            return x
        return x / norm

    def _run(self, x):  # pylint: disable=W0221
        """
        Applies the norm function selected in the constructor.
        The computation overwrites *x* when ``self.inplaces`` maps
        key 0 to a true value.

        @param      x       array to normalize
        @return             tuple with one element, the normalized array
        """
        return (self._norm(x, inplace=self.inplaces.get(0, False)), )