Coverage for mlprodict/onnxrt/ops_cpu/op_tfidfvectorizer.py: 98%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRunUnary, RuntimeTypeError
9from ..shape_object import ShapeObject
10from .op_tfidfvectorizer_ import RuntimeTfIdfVectorizer # pylint: disable=E0611,E0401
class TfIdfVectorizer(OpRunUnary):
    """
    Runtime for operator *TfIdfVectorizer*.

    The n-gram counting itself is delegated to ``RuntimeTfIdfVectorizer``,
    which is always initialised with an integer pool. When the node
    declares ``pool_strings``, every input string is first replaced by its
    position in that pool (``-1`` when the string is not in the pool)
    and the runtime works on those positions.
    """

    # Default values of the node attributes, handed to OpRunUnary
    # as *expected_attributes*.
    atts = {'max_gram_length': 1,
            'max_skip_count': 1,
            'min_gram_length': 1,
            'mode': b'TF',
            'ngram_counts': [],
            'ngram_indexes': [],
            'pool_int64s': [],
            'pool_strings': [],
            'weights': []}

    def __init__(self, onnx_node, desc=None, **options):
        """
        @param      onnx_node   node forwarded to OpRunUnary
        @param      desc        forwarded to OpRunUnary
        @param      options     forwarded to OpRunUnary
        """
        OpRunUnary.__init__(self, onnx_node, desc=desc,
                            expected_attributes=TfIdfVectorizer.atts,
                            **options)
        self.rt_ = RuntimeTfIdfVectorizer()
        if len(self.pool_strings) != 0:
            # String pool: the runtime receives the positions
            # 0..len(pool)-1 as its integer pool, and *mapping*
            # converts a decoded string into that position.
            pool_int64s = list(range(len(self.pool_strings)))
            pool_strings_ = numpy.array(
                [_.decode('utf-8') for _ in self.pool_strings])
            mapping = {w: i for i, w in enumerate(pool_strings_)}
        else:
            # Integer pool: used as is, no conversion needed in _run.
            mapping = None
            pool_int64s = self.pool_int64s
            pool_strings_ = None

        self.mapping_ = mapping
        self.pool_strings_ = pool_strings_
        self.rt_.init(
            self.max_gram_length, self.max_skip_count, self.min_gram_length,
            self.mode, self.ngram_counts, self.ngram_indexes, pool_int64s,
            self.weights)

    def _run(self, x):  # pylint: disable=W0221
        """
        Computes the output for input *x*.

        @param      x       numpy array, integers if the node has no
                            string pool, strings otherwise
        @return             tuple with one array reshaped into
                            ``(x.shape[0], -1)``
        """
        if self.mapping_ is None:
            ids = x
        else:
            # Strings are converted element by element over the flattened
            # array so the conversion does not depend on the rank of *x*;
            # a string missing from the pool becomes -1.
            lookup = self.mapping_.get
            ids = numpy.fromiter(
                (lookup(w, -1) for w in x.ravel()),
                dtype=numpy.int64, count=x.size).reshape(x.shape)
        res = self.rt_.compute(ids)
        return (res.reshape((x.shape[0], -1)), )

    def _infer_shapes(self, x):  # pylint: disable=E0202,W0221
        """
        Infers the output shape from the input shape *x*.

        @param      x       shape of the input
        @return             tuple with one shape, the input dimensions
                            followed by one unknown dimension
        @raise      RuntimeTypeError    if *x* has more than two dimensions
        """
        if x.shape is None:
            return (x, )
        # NOTE(review): _run always reshapes its result into two dimensions
        # whereas a two-dimensional input is declared here with three;
        # confirm which one is intended.
        if len(x) == 1:
            return (ShapeObject((x[0], None), dtype=x.dtype,
                                name=self.__class__.__name__), )
        if len(x) == 2:
            return (ShapeObject((x[0], x[1], None), dtype=x.dtype,
                                name=self.__class__.__name__), )
        raise RuntimeTypeError(  # pragma: no cover
            "Only two dimension are allowed, got {}.".format(x))

    def _infer_types(self, x):  # pylint: disable=E0202,W0221
        """
        Returns the input type unchanged as the output type.
        """
        return (x, )