Coverage for mlprodict/onnxrt/ops_cpu/op_dict_vectorizer.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

37 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from scipy.sparse import coo_matrix 

9from ._op import OpRun, RuntimeTypeError 

10from ..shape_object import ShapeObject 

11 

12 

class DictVectorizer(OpRun):
    """
    Runtime for operator *DictVectorizer*: turns a sequence of
    dictionaries into a sparse matrix with one row per dictionary and
    one column per vocabulary entry (either *int64_vocabulary* or
    *string_vocabulary*).
    """

    atts = {'int64_vocabulary': numpy.empty(0, dtype=numpy.int64),
            'string_vocabulary': numpy.empty(0, dtype=numpy.str_)}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=DictVectorizer.atts,
                       **options)
        # Maps each vocabulary entry to its column index in the output.
        self.dict_labels = {}
        if len(self.int64_vocabulary) > 0:
            for i, v in enumerate(self.int64_vocabulary):
                self.dict_labels[v] = i
            self.is_int = True
        else:
            for i, v in enumerate(self.string_vocabulary):
                # String attributes may arrive as bytes (decoded here) or
                # already as str -- the declared default above is a
                # numpy.str_ array, which has no ``decode`` method.
                key = v.decode('utf-8') if isinstance(v, bytes) else str(v)
                self.dict_labels[key] = i
            self.is_int = False
        if len(self.dict_labels) == 0:
            raise RuntimeError(  # pragma: no cover
                "int64_vocabulary and string_vocabulary cannot be both empty.")

    def _run(self, x):  # pylint: disable=W0221
        """
        Vectorizes *x*.

        :param x: list or numpy array of dictionaries
        :return: tuple containing one scipy *coo_matrix* of shape
            ``(len(x), len(vocabulary))``; a vocabulary entry absent from
            a dictionary stays an implicit zero, a dictionary key absent
            from the vocabulary is ignored
        :raises RuntimeTypeError: if *x* is neither a list nor a numpy array
        """
        if not isinstance(x, (numpy.ndarray, list)):
            raise RuntimeTypeError(  # pragma: no cover
                "x must be iterable not {}.".format(type(x)))
        labels = self.dict_labels
        values = []
        rows = []
        cols = []
        for i, row in enumerate(x):
            for k, v in row.items():
                col = labels.get(k, None)
                if col is None:
                    # Keys outside the vocabulary have no output column:
                    # skip them (as onnxruntime and scikit-learn do) rather
                    # than failing with a KeyError.
                    continue
                values.append(v)
                rows.append(i)
                cols.append(col)
        values = numpy.array(values)
        # Explicit integer dtype so that an empty input still produces
        # integer index arrays.
        rows = numpy.array(rows, dtype=numpy.int64)
        cols = numpy.array(cols, dtype=numpy.int64)
        return (coo_matrix((values, (rows, cols)),
                           shape=(len(x), len(labels))), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns a symbolic two-dimensional output shape whose dimension
        names are made unique per instance with ``id(self)``.
        """
        pref = hex(id(self))[2:]
        # NOTE(review): ``_run`` always produces ``len(self.dict_labels)``
        # columns, so the second dimension is known statically; it is kept
        # symbolic here -- confirm whether ShapeObject should receive it.
        return (ShapeObject(["ndv%s_0" % pref, "N%s_1" % pref], dtype=x.dtype), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the input type unchanged as the output type.
        """
        return (x, )