Coverage for mlprodict/onnxrt/ops_cpu/op_tokenizer.py: 93%
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import re
import numpy
from ._op import OpRunUnary, RuntimeTypeError
from ._new_ops import OperatorSchema
from ..shape_object import ShapeObject


class Tokenizer(OpRunUnary):
    """
    See :epkg:`Tokenizer`.
    """

    atts = {'mark': 0,
            'mincharnum': 1,
            'pad_value': b'#',
            'separators': [],
            'tokenexp': b'[a-zA-Z0-9_]+',
            'tokenexpsplit': 0,
            'stopwords': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=Tokenizer.atts,
                            **options)
        # Character-level tokenization is used when tokenexp is '.' or
        # when the only separator is the empty string.
        self.char_tokenization_ = (
            self.tokenexp == b'.' or list(self.separators) == [b''])
        self.stops_ = set(_.decode() for _ in self.stopwords)
        try:
            self.str_separators_ = set(_.decode('utf-8')
                                       for _ in self.separators)
        except AttributeError as e:  # pragma: no cover
            raise RuntimeTypeError(
                "Unable to interpret separators {}.".format(self.separators)) from e
        if self.tokenexp not in (None, b''):
            self.tokenexp_ = re.compile(self.tokenexp.decode('utf-8'))

    def _find_custom_operator_schema(self, op_name):
        if op_name == "Tokenizer":
            return TokenizerSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    def _run(self, text):  # pylint: disable=W0221
        # Dispatch order: character-level tokenization first, then
        # separator-based tokenization, then regular-expression tokenization.
        if self.char_tokenization_:
            return self._run_char_tokenization(text, self.stops_)
        if self.str_separators_ is not None and len(self.str_separators_) > 0:
            return self._run_sep_tokenization(
                text, self.stops_, self.str_separators_)
        if self.tokenexp not in (None, b''):
            return self._run_regex_tokenization(
                text, self.stops_, self.tokenexp_)
        raise RuntimeError(  # pragma: no cover
            "Unable to guess which tokenization to use, sep={}, "
            "tokenexp='{}'.".format(self.separators, self.tokenexp))

    def _run_tokenization(self, text, stops, split):
        """
        Applies the tokenization defined by function *split*
        and pads every row with *pad_value*.
        """
        max_len = max(map(len, text.flatten()))
        if self.mark:
            max_len += 2
            begin = 1
        else:
            begin = 0
        shape = text.shape + (max_len, )
        max_pos = 0
        res = numpy.empty(shape, dtype=text.dtype)
        if len(text.shape) == 1:
            res[:] = self.pad_value
            for i in range(text.shape[0]):
                pos = begin
                for c in split(text[i]):
                    if c not in stops:
                        res[i, pos] = c
                        pos += 1
                if self.mark:
                    res[i, 0] = self.pad_value
                    max_pos = max(pos + 1, max_pos)
                else:
                    max_pos = max(pos, max_pos)
            res = res[:, :max_pos]
        elif len(text.shape) == 2:
            res[:, :] = self.pad_value
            for i in range(text.shape[0]):
                for ii in range(text.shape[1]):
                    pos = begin
                    for c in split(text[i, ii]):
                        if c not in stops:
                            res[i, ii, pos] = c
                            pos += 1
                    if self.mark:
                        res[i, ii, 0] = self.pad_value
                        max_pos = max(pos + 1, max_pos)
                    else:
                        max_pos = max(pos, max_pos)
            res = res[:, :, :max_pos]
        else:
            raise RuntimeError(  # pragma: no cover
                "Only vectors or matrices are supported, not shape {}.".format(text.shape))
        return (res, )
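
    # A worked example of the padding logic above (the attribute values are
    # illustrative, not taken from a real model): with mark=0, pad value '#'
    # and separators [' '], the input numpy.array(['a b', 'c']) produces a
    # (2, 2) array [['a', 'b'], ['c', '#']]: the shorter row is padded with
    # pad_value and the result is truncated to the longest token count
    # (max_pos).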

    def _run_char_tokenization(self, text, stops):
        """
        Tokenizes by characters.
        """
        def split(t):
            for c in t:
                yield c
        return self._run_tokenization(text, stops, split)
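
    # Example of the character-level split above: split("ab c") yields
    # 'a', 'b', ' ', 'c'; characters listed in the stopwords attribute are
    # then dropped by _run_tokenization.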

    def _run_sep_tokenization(self, text, stops, separators):
        """
        Tokenizes using separators.
        A trie would make the search for separators faster.
        """
        def split(t):
            begin = 0
            pos = 0
            while pos < len(t):
                for sep in separators:
                    if (pos + len(sep) <= len(t) and
                            sep == t[pos: pos + len(sep)]):
                        word = t[begin: pos]
                        yield word
                        begin = pos + len(sep)
                        break
                pos += 1
            if begin < pos:
                word = t[begin: pos]
                yield word

        return self._run_tokenization(text, stops, split)
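
    # Example of the separator split above: with separators={';', ','},
    # split("a,b;c") yields 'a', 'b', 'c'; the text between two consecutive
    # separators becomes one token.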

    def _run_regex_tokenization(self, text, stops, exp):
        """
        Tokenizes using a regular expression.
        """
        if self.tokenexpsplit:
            def split(t):
                return filter(lambda x: x, exp.split(t))
        else:
            def split(t):
                return filter(lambda x: x, exp.findall(t))
        return self._run_tokenization(text, stops, split)
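
    # Example with the default tokenexp '[a-zA-Z0-9_]+': tokenexpsplit=0
    # keeps the matches, so exp.findall("ab, cd") gives ['ab', 'cd'];
    # tokenexpsplit=1 keeps the pieces between matches instead (exp.split),
    # empty strings being filtered out.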

    def _infer_shapes(self, x):  # pylint: disable=E0202,W0221
        if x.shape is None:
            return (x, )
        if len(x) == 1:
            return (ShapeObject((x[0], None), dtype=x.dtype,
                                name=self.__class__.__name__), )
        if len(x) == 2:
            return (ShapeObject((x[0], x[1], None), dtype=x.dtype,
                                name=self.__class__.__name__), )
        raise RuntimeTypeError(  # pragma: no cover
            "Only two dimensions are allowed, got {}.".format(x))

    def _infer_types(self, x):  # pylint: disable=E0202,W0221
        return (x, )


class TokenizerSchema(OperatorSchema):
    """
    Defines a schema for operators added in this package
    such as @see cl Tokenizer.
    """

    def __init__(self):
        OperatorSchema.__init__(self, 'Tokenizer')
        self.attributes = Tokenizer.atts
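

# A minimal usage sketch (hypothetical model: the 'com.microsoft' domain, the
# tensor names and the attribute values below are illustrative, not taken
# from this file):
#
#     from onnx import helper, TensorProto
#     from mlprodict.onnxrt import OnnxInference
#     node = helper.make_node(
#         'Tokenizer', ['text'], ['tokens'], domain='com.microsoft',
#         mark=0, mincharnum=1, pad_value='#', separators=[' '])
#     graph = helper.make_graph(
#         [node], 'tokenizer_graph',
#         [helper.make_tensor_value_info('text', TensorProto.STRING, [None])],
#         [helper.make_tensor_value_info('tokens', TensorProto.STRING,
#                                        [None, None])])
#     oinf = OnnxInference(helper.make_model(graph), runtime='python')
#     print(oinf.run({'text': numpy.array(['a b', 'c'])}))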