Coverage for mlprodict/onnxrt/ops_cpu/op_string_normalizer.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import unicodedata
8import locale
9import warnings
10import numpy
11from ._op import OpRunUnary, RuntimeTypeError
class StringNormalizer(OpRunUnary):
    """
    Runtime for the ONNX *StringNormalizer* operator.

    The operator is not really threadsafe as python cannot
    play with two locales at the same time. Stop words
    should not be implemented here as the tokenization
    usually happens after this step.
    """

    atts = {'case_change_action': b'NONE',  # LOWER, UPPER or NONE
            'is_case_sensitive': 1,
            'locale': b'',
            'stopwords': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=StringNormalizer.atts,
                            **options)
        self.slocale = self.locale.decode('ascii')
        self.stops = set(self.stopwords)
        # Stop words come from the ONNX attribute as bytes; decode them
        # once here instead of on every call to _run_column.
        self._decoded_stops = set(w.decode() for w in self.stops)
        # Last locale string successfully passed to locale.setlocale.
        # locale.getlocale() returns a tuple and can never compare equal
        # to a string, so caching the string is the only reliable way to
        # avoid re-applying the same locale on every column.
        self._applied_locale = None

    def _run(self, x):  # pylint: disable=W0221
        """
        Normalizes strings.

        :param x: numpy array of strings, one or two dimensions
        :return: tuple with one array of the same shape and dtype
        :raises RuntimeTypeError: if *x* has more than two dimensions
        """
        res = numpy.empty(x.shape, dtype=x.dtype)
        if len(x.shape) == 2:
            # Normalization is applied column by column.
            for i in range(0, x.shape[1]):
                self._run_column(x[:, i], res[:, i])
        elif len(x.shape) == 1:
            self._run_column(x, res)
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "x must be a matrix or a vector.")
        return (res, )

    def _run_column(self, cin, cout):
        """
        Normalizes strings in a column.

        :param cin: input column (1D array of strings)
        :param cout: output column, modified in place
        :return: *cout*
        :raises RuntimeError: if *case_change_action* is unknown
        """
        if self._applied_locale != self.slocale:
            try:
                locale.setlocale(locale.LC_ALL, self.slocale)
                # Remember what was applied so the next call can skip
                # the (process-global) setlocale.
                self._applied_locale = self.slocale
            except locale.Error as e:
                warnings.warn(
                    "Unknown local setting '{}' (current: '{}') - {}."
                    "".format(self.slocale, locale.getlocale(), e))
        stops = self._decoded_stops
        cout[:] = cin[:]

        # Strip accents first; a float in a string column means nan.
        for i in range(0, cin.shape[0]):
            if isinstance(cout[i], float):
                # nan
                cout[i] = ''  # pragma: no cover
            else:
                cout[i] = self.strip_accents_unicode(cout[i])

        if self.is_case_sensitive and len(stops) > 0:
            # Case-sensitive removal happens before any case change.
            for i in range(0, cin.shape[0]):
                cout[i] = self._remove_stopwords(cout[i], stops)

        if self.case_change_action == b'LOWER':
            for i in range(0, cin.shape[0]):
                cout[i] = cout[i].lower()
        elif self.case_change_action == b'UPPER':
            for i in range(0, cin.shape[0]):
                cout[i] = cout[i].upper()
        elif self.case_change_action != b'NONE':
            raise RuntimeError(
                "Unknown option for case_change_action: {}.".format(
                    self.case_change_action))

        if not self.is_case_sensitive and len(stops) > 0:
            # NOTE(review): the stop words themselves are not case-folded
            # here, so insensitivity only comes from the case change
            # applied above — confirm against the ONNX specification.
            for i in range(0, cin.shape[0]):
                cout[i] = self._remove_stopwords(cout[i], stops)

        return cout

    def _remove_stopwords(self, text, stops):
        """
        Removes from *text* every space-separated token present in *stops*.
        """
        spl = text.split(' ')
        return ' '.join(filter(lambda s: s not in stops, spl))

    def strip_accents_unicode(self, s):
        """
        Transforms accentuated unicode symbols into their simple counterpart.
        Source: `sklearn/feature_extraction/text.py
        <https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/
        feature_extraction/text.py#L115>`_.

        :param s: string
            The string to strip
        :return: the cleaned string
        """
        try:
            # If `s` is ASCII-compatible, then it does not contain any accented
            # characters and we can avoid an expensive list comprehension
            s.encode("ASCII", errors="strict")
            return s
        except UnicodeEncodeError:
            # Decompose, then drop the combining marks (the accents).
            normalized = unicodedata.normalize('NFKD', s)
            s = ''.join(
                [c for c in normalized if not unicodedata.combining(c)])
            return s

    def _infer_shapes(self, x):  # pylint: disable=E0202,W0221
        # Output has the same shape and type as the input.
        return (x, )