Coverage for mlprodict/onnxrt/ops_cpu/op_tokenizer.py: 93%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

104 statements  

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import re
import numpy
from ._op import OpRunUnary, RuntimeTypeError
from ._new_ops import OperatorSchema
from ..shape_object import ShapeObject


class Tokenizer(OpRunUnary):
    """
    See :epkg:`Tokenizer`.
    """

    atts = {'mark': 0,
            'mincharnum': 1,
            'pad_value': b'#',
            'separators': [],
            'tokenexp': b'[a-zA-Z0-9_]+',
            'tokenexpsplit': 0,
            'stopwords': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=Tokenizer.atts,
                            **options)
        # Character-level tokenization is requested either with
        # tokenexp '.' or with a single empty separator.
        self.char_tokenization_ = (
            self.tokenexp == b'.' or list(self.separators) == [b''])
        self.stops_ = set(_.decode() for _ in self.stopwords)
        try:
            self.str_separators_ = set(_.decode('utf-8')
                                       for _ in self.separators)
        except AttributeError as e:  # pragma: no cover
            raise RuntimeTypeError(
                "Unable to interpret separators {}.".format(self.separators)) from e
        if self.tokenexp not in (None, b''):
            self.tokenexp_ = re.compile(self.tokenexp.decode('utf-8'))

    def _find_custom_operator_schema(self, op_name):
        if op_name == "Tokenizer":
            return TokenizerSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    def _run(self, text):  # pylint: disable=W0221
        # Dispatch: character-level first, then separator-based,
        # then regular expression.
        if self.char_tokenization_:
            return self._run_char_tokenization(text, self.stops_)
        if self.str_separators_ is not None and len(self.str_separators_) > 0:
            return self._run_sep_tokenization(
                text, self.stops_, self.str_separators_)
        if self.tokenexp not in (None, ''):
            return self._run_regex_tokenization(
                text, self.stops_, self.tokenexp_)
        raise RuntimeError(  # pragma: no cover
            "Unable to guess which tokenization to use, sep={}, "
            "tokenexp='{}'.".format(self.separators, self.tokenexp))

    def _run_tokenization(self, text, stops, split):
        """
        Tokenizes an array of strings with the *split* function,
        drops tokens found in *stops* and pads the result with *pad_value*.
        """
        max_len = max(map(len, text.flatten()))
        if self.mark:
            max_len += 2
            begin = 1
        else:
            begin = 0
        shape = text.shape + (max_len, )
        max_pos = 0
        res = numpy.empty(shape, dtype=text.dtype)
        if len(text.shape) == 1:
            res[:] = self.pad_value
            for i in range(text.shape[0]):
                pos = begin
                for c in split(text[i]):
                    if c not in stops:
                        res[i, pos] = c
                        pos += 1
                if self.mark:
                    res[i, 0] = self.pad_value
                    max_pos = max(pos + 1, max_pos)
                else:
                    max_pos = max(pos, max_pos)
            res = res[:, :max_pos]
        elif len(text.shape) == 2:
            res[:, :] = self.pad_value
            for i in range(text.shape[0]):
                for ii in range(text.shape[1]):
                    pos = begin
                    for c in split(text[i, ii]):
                        if c not in stops:
                            res[i, ii, pos] = c
                            pos += 1
                    if self.mark:
                        res[i, ii, 0] = self.pad_value
                        max_pos = max(pos + 1, max_pos)
                    else:
                        max_pos = max(pos, max_pos)
            res = res[:, :, :max_pos]
        else:
            raise RuntimeError(  # pragma: no cover
                "Only vectors or matrices are supported, not shape {}.".format(
                    text.shape))
        return (res, )

    def _run_char_tokenization(self, text, stops):
        """
        Tokenizes by characters.
        """
        def split(t):
            for c in t:
                yield c
        return self._run_tokenization(text, stops, split)
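    # For instance, 'abc' is split into the characters 'a', 'b', 'c' before
    # stopword filtering and padding in _run_tokenization.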

    def _run_sep_tokenization(self, text, stops, separators):
        """
        Tokenizes using separators.
        The function should use a trie to find text.
        """
        def split(t):
            begin = 0
            pos = 0
            while pos < len(t):
                for sep in separators:
                    if (pos + len(sep) <= len(t) and
                            sep == t[pos: pos + len(sep)]):
                        word = t[begin: pos]
                        yield word
                        begin = pos + len(sep)
                        break
                pos += 1
            if begin < pos:
                word = t[begin: pos]
                yield word

        return self._run_tokenization(text, stops, split)
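    # For instance, with separators {',', ';'} the string 'a,b;c' yields the
    # tokens 'a', 'b', 'c'; separators are matched one position at a time,
    # no trie is built yet.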

    def _run_regex_tokenization(self, text, stops, exp):
        """
        Tokenizes using a regular expression.
        """
        if self.tokenexpsplit:
            def split(t):
                return filter(lambda x: x, exp.split(t))
        else:
            def split(t):
                return filter(lambda x: x, exp.findall(t))
        return self._run_tokenization(text, stops, split)
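    # For instance, with the default tokenexp b'[a-zA-Z0-9_]+' and
    # tokenexpsplit=0, 'a-b c' yields 'a', 'b', 'c' through findall;
    # with tokenexp=b'-' and tokenexpsplit=1, 'a-b' yields 'a', 'b'
    # through split. Empty matches are filtered out in both cases.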

    def _infer_shapes(self, x):  # pylint: disable=E0202,W0221
        if x.shape is None:
            return (x, )
        if len(x) == 1:
            return (ShapeObject((x[0], None), dtype=x.dtype,
                                name=self.__class__.__name__), )
        if len(x) == 2:
            return (ShapeObject((x[0], x[1], None), dtype=x.dtype,
                                name=self.__class__.__name__), )
        raise RuntimeTypeError(  # pragma: no cover
            "Only two dimensions are allowed, got {}.".format(x))

    def _infer_types(self, x):  # pylint: disable=E0202,W0221
        return (x, )


class TokenizerSchema(OperatorSchema):
    """
    Defines a schema for operators added in this package
    such as @see cl TreeEnsembleClassifierDouble.
    """

    def __init__(self):
        OperatorSchema.__init__(self, 'Tokenizer')
        self.attributes = Tokenizer.atts
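
A minimal sketch of how this runtime class may be exercised end to end, assuming a model that contains a single Tokenizer node in the com.microsoft domain. The graph name and the input/output names ('text', 'tokens') below are illustrative; only the attribute names mirror Tokenizer.atts above.

import numpy
from onnx import helper, TensorProto
from mlprodict.onnxrt import OnnxInference

# Illustrative graph: one Tokenizer node taking a 1-D tensor of strings.
node = helper.make_node(
    'Tokenizer', ['text'], ['tokens'], domain='com.microsoft',
    mark=0, mincharnum=1, tokenexp='[a-zA-Z0-9_]+')
graph = helper.make_graph(
    [node], 'tokenizer_example',
    [helper.make_tensor_value_info('text', TensorProto.STRING, [None])],
    [helper.make_tensor_value_info('tokens', TensorProto.STRING, [None, None])])
model = helper.make_model(
    graph, opset_imports=[helper.make_opsetid('', 14),
                          helper.make_opsetid('com.microsoft', 1)])

# The python runtime resolves the node to the Tokenizer class above.
oinf = OnnxInference(model)
print(oinf.run({'text': numpy.array(['first sentence', 'second one'])}))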