Coverage for mlprodict/onnxrt/ops_cpu/op_string_normalizer.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

59 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import unicodedata 

8import locale 

9import warnings 

10import numpy 

11from ._op import OpRunUnary, RuntimeTypeError 

12 

13 

class StringNormalizer(OpRunUnary):
    """
    Runtime implementation of the *StringNormalizer* operator.

    Every string goes through the following steps, in this order:

    * accents are stripped (see :meth:`strip_accents_unicode`),
    * stop words are removed when the comparison is case-sensitive,
    * the string is lowercased or uppercased depending on
      *case_change_action*,
    * stop words are removed when the comparison is case-insensitive.

    Stop words are matched against the space-separated tokens of each
    string. Stop words should ideally not be handled here because
    tokenization usually happens after this step.

    The operator is not really thread-safe: python cannot use two
    locales at the same time, and :func:`locale.setlocale` changes a
    process-wide setting.
    """

    # NOTE(review): the ONNX specification documents is_case_sensitive
    # with a default of 0; this runtime uses 1. Kept for backward
    # compatibility - confirm before changing.
    atts = {'case_change_action': b'NONE',  # LOWER UPPER NONE
            'is_case_sensitive': 1,
            'locale': b'',
            'stopwords': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=StringNormalizer.atts,
                            **options)
        self.slocale = self.locale.decode('ascii')
        self.stops = set(self.stopwords)

    def _run(self, x):  # pylint: disable=W0221
        """
        Normalizes strings.

        :param x: vector or matrix of strings; a matrix is processed
            column by column
        :return: a one-element tuple holding the normalized array
            (same shape and dtype as *x*)
        :raises RuntimeTypeError: if *x* has more than two dimensions
        """
        res = numpy.empty(x.shape, dtype=x.dtype)
        if len(x.shape) == 2:
            for i in range(0, x.shape[1]):
                # res[:, i] is a view, _run_column fills it in place
                self._run_column(x[:, i], res[:, i])
        elif len(x.shape) == 1:
            self._run_column(x, res)
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "x must be a matrix or a vector.")
        return (res, )

    def _ensure_locale(self):
        """
        Switches the process-wide locale to *self.slocale* unless it is
        already the current setting. Emits a warning if the locale is
        not available.
        """
        # locale.getlocale() returns a tuple and can never be equal to a
        # string; setlocale without a value returns the current setting
        # as a string, which makes the comparison meaningful.
        if locale.setlocale(locale.LC_ALL) == self.slocale:
            return
        try:
            locale.setlocale(locale.LC_ALL, self.slocale)
        except locale.Error as e:
            warnings.warn(
                "Unknown local setting '{}' (current: '{}') - {}."
                "".format(self.slocale, locale.getlocale(), e))

    def _decoded_stopwords(self):
        """
        Returns the stop words as a set of :class:`str`. Values stored as
        bytes are decoded; values already given as strings are kept.
        """
        return {w.decode() if isinstance(w, bytes) else w
                for w in self.stops}

    def _run_column(self, cin, cout):
        """
        Normalizes the strings of one column.

        :param cin: 1-D array of input strings (float values, i.e. NaN,
            become empty strings)
        :param cout: 1-D array receiving the normalized strings,
            modified in place
        :return: *cout*
        :raises RuntimeError: if *case_change_action* is not one of
            ``b'LOWER'``, ``b'UPPER'``, ``b'NONE'``
        """
        self._ensure_locale()
        stops = self._decoded_stopwords()
        cout[:] = cin[:]
        n = cin.shape[0]

        for i in range(0, n):
            if isinstance(cout[i], float):
                # nan
                cout[i] = ''  # pragma: no cover
            else:
                cout[i] = self.strip_accents_unicode(cout[i])

        if self.is_case_sensitive and stops:
            for i in range(0, n):
                cout[i] = self._remove_stopwords(cout[i], stops)

        if self.case_change_action == b'LOWER':
            for i in range(0, n):
                cout[i] = cout[i].lower()
        elif self.case_change_action == b'UPPER':
            for i in range(0, n):
                cout[i] = cout[i].upper()
        elif self.case_change_action != b'NONE':
            raise RuntimeError(
                "Unknown option for case_change_action: {}.".format(
                    self.case_change_action))

        if not self.is_case_sensitive and stops:
            # Compare tokens and stop words in casefolded form. Otherwise
            # a stop word such as 'The' would never match once the text
            # has been lowercased or uppercased.
            folded = {w.casefold() for w in stops}
            for i in range(0, n):
                cout[i] = self._remove_stopwords(
                    cout[i], folded, case_sensitive=False)

        return cout

    def _remove_stopwords(self, text, stops, case_sensitive=True):
        """
        Removes the space-separated tokens of *text* that are stop words.

        :param text: string to filter
        :param stops: set of stop words; it must already be casefolded
            when *case_sensitive* is False
        :param case_sensitive: if False, each token is casefolded before
            it is looked up in *stops*
        :return: the remaining tokens joined with a single space
        """
        tokens = text.split(' ')
        if case_sensitive:
            kept = [t for t in tokens if t not in stops]
        else:
            kept = [t for t in tokens if t.casefold() not in stops]
        return ' '.join(kept)

    def strip_accents_unicode(self, s):
        """
        Transforms accented unicode characters into their unaccented
        counterpart. Source: `sklearn/feature_extraction/text.py
        <https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/feature_extraction/text.py#L115>`_.

        :param s: string to strip
        :return: the cleaned string
        """
        try:
            # If `s` is ASCII-compatible, then it does not contain any
            # accented characters and we can avoid an expensive list
            # comprehension.
            s.encode("ASCII", errors="strict")
            return s
        except UnicodeEncodeError:
            # NFKD splits an accented character into its base character
            # followed by combining marks, which are then dropped.
            normalized = unicodedata.normalize('NFKD', s)
            return ''.join(
                [c for c in normalized if not unicodedata.combining(c)])

    def _infer_shapes(self, x):  # pylint: disable=E0202,W0221
        """Output shape and type are the same as the input ones."""
        return (x, )