Coverage for mlprodict/onnxrt/ops_cpu/op_dict_vectorizer.py: 97%
Keyboard shortcuts on this page:
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from scipy.sparse import coo_matrix
9from ._op import OpRun, RuntimeTypeError
10from ..shape_object import ShapeObject
class DictVectorizer(OpRun):
    """
    Runtime for operator *DictVectorizer*: converts a sequence of
    dictionaries into a sparse matrix with one row per dictionary and
    one column per vocabulary entry (``int64_vocabulary`` if it is not
    empty, ``string_vocabulary`` otherwise).
    """

    atts = {'int64_vocabulary': numpy.empty(0, dtype=numpy.int64),
            'string_vocabulary': numpy.empty(0, dtype=numpy.str_)}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=DictVectorizer.atts,
                       **options)
        # Maps every vocabulary entry to its column index in the output.
        self.dict_labels = {}
        if len(self.int64_vocabulary) > 0:
            for i, v in enumerate(self.int64_vocabulary):
                self.dict_labels[v] = i
            self.is_int = True
        else:
            for i, v in enumerate(self.string_vocabulary):
                # String attributes may arrive as bytes (decoded here) or
                # already as str (the default above is a numpy.str_ array);
                # calling .decode on a str would raise AttributeError.
                key = v.decode('utf-8') if isinstance(v, bytes) else v
                self.dict_labels[key] = i
            self.is_int = False
        if len(self.dict_labels) == 0:
            raise RuntimeError(  # pragma: no cover
                "int64_vocabulary and string_vocabulary cannot be both empty.")

    def _run(self, x):  # pylint: disable=W0221
        """
        Vectorizes *x*.

        :param x: list or numpy array of dictionaries, one per output row
        :return: tuple holding one ``scipy.sparse.coo_matrix`` of shape
            ``(len(x), len(self.dict_labels))``
        :raises RuntimeTypeError: if *x* is neither a list nor an array
        :raises KeyError: if a dictionary holds a key which is not part
            of the vocabulary
        """
        if not isinstance(x, (numpy.ndarray, list)):
            raise RuntimeTypeError(  # pragma: no cover
                "x must be iterable not {}.".format(type(x)))
        values = []
        rows = []
        cols = []
        for i, row in enumerate(x):
            for k, v in row.items():
                try:
                    col = self.dict_labels[k]
                except KeyError:
                    # Keep the KeyError type callers may rely on, but say
                    # which key and which row triggered it.
                    raise KeyError(
                        "Key {!r} (row {}) is not in the vocabulary.".format(
                            k, i)) from None
                values.append(v)
                rows.append(i)
                cols.append(col)
        values = numpy.array(values)
        # Explicit integer dtype: an empty input would otherwise produce
        # float64 index arrays.
        rows = numpy.array(rows, dtype=numpy.int64)
        cols = numpy.array(cols, dtype=numpy.int64)
        return (coo_matrix((values, (rows, cols)),
                           shape=(len(x), len(self.dict_labels))), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns the output shape with two symbolic dimensions made unique
        per operator instance through the hexadecimal id of this object.
        """
        # NOTE(review): the second dimension is len(self.dict_labels) at
        # runtime; confirm whether ShapeObject accepts a concrete int here.
        pref = hex(id(self))[2:]
        return (ShapeObject(["ndv%s_0" % pref, "N%s_1" % pref],
                            dtype=x.dtype), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """Returns the input type unchanged."""
        return (x, )