Coverage for mlprodict/onnxrt/ops_cpu/op_label_encoder.py: 93%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

59 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ..shape_object import ShapeObject 

9from ._op import OpRun 

10 

11 

class LabelEncoder(OpRun):
    """
    Runtime for operator *LabelEncoder*: maps every element of a
    one-dimensional input to a value through a lookup table built from
    one ``keys_*`` attribute and one ``values_*`` attribute. Elements
    absent from the table are replaced by the default value matching the
    type of the values.
    """

    atts = {'default_float': 0., 'default_int64': -1,
            'default_string': b'',
            'keys_floats': numpy.empty(0, dtype=numpy.float32),
            'keys_int64s': numpy.empty(0, dtype=numpy.int64),
            'keys_strings': numpy.empty(0, dtype=numpy.str_),
            'values_floats': numpy.empty(0, dtype=numpy.float32),
            'values_int64s': numpy.empty(0, dtype=numpy.int64),
            'values_strings': numpy.empty(0, dtype=numpy.str_),
            }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the lookup table ``classes_``, the fallback value
        ``default_`` and the output type ``dtype_`` from the node
        attributes.

        @param      onnx_node   node forwarded to ``OpRun.__init__``
        @param      desc        forwarded to ``OpRun.__init__``
        @param      options     forwarded to ``OpRun.__init__``

        Raises *RuntimeError* if no pair of non-empty keys / values
        attributes is found or if the resulting table is empty.
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=LabelEncoder.atts,
                       **options)

        # Pairs (keys attribute, values attribute) tested in this order;
        # the first pair whose both sides are non-empty defines the table.
        candidates = (
            ('keys_floats', 'values_floats'),
            ('keys_floats', 'values_int64s'),
            ('keys_int64s', 'values_int64s'),
            ('keys_int64s', 'values_floats'),
            ('keys_strings', 'values_floats'),
            ('keys_strings', 'values_int64s'),
            ('keys_strings', 'values_strings'),
            ('keys_floats', 'values_strings'),
            ('keys_int64s', 'values_strings'),
        )
        for keys_name, values_name in candidates:
            if (len(getattr(self, keys_name)) > 0 and
                    len(getattr(self, values_name)) > 0):
                break
        else:
            if hasattr(self, 'classes_strings'):
                raise RuntimeError(  # pragma: no cover
                    "This runtime does not implement version 1 of "
                    "operator LabelEncoder.")
            raise RuntimeError(
                "No encoding was defined in {}.".format(onnx_node))

        keys = getattr(self, keys_name)
        values = getattr(self, values_name)

        # String attributes are decoded so that the table only holds text.
        if keys_name == 'keys_strings':
            keys = [self._as_text(k) for k in keys]

        # The fallback value and the output type always follow the type of
        # the values, never the type of the keys.
        if values_name == 'values_floats':
            self.default_ = self.default_float
            self.dtype_ = numpy.float32
        elif values_name == 'values_int64s':
            self.default_ = self.default_int64
            self.dtype_ = numpy.int64
        else:
            values = [self._as_text(v) for v in values]
            # The default is decoded the same way as the values so the
            # output never mixes bytes and str.
            self.default_ = self._as_text(self.default_string)
            # Strings are stored in an array of python objects.
            self.dtype_ = numpy.dtype(object)

        self.classes_ = dict(zip(keys, values))

        if len(self.classes_) == 0:
            raise RuntimeError(  # pragma: no cover
                "Empty classes for LabelEncoder, (onnx_node='{}')\n{}.".format(
                    self.onnx_node.name, onnx_node))

    @staticmethod
    def _as_text(value):
        """
        Decodes *value* from UTF-8 if it is a *bytes* instance,
        returns it unchanged otherwise.
        """
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    def _run(self, x):  # pylint: disable=W0221
        """
        Encodes every element of *x*.

        @param      x       array, squeezed if it has more than
                            one dimension
        @return             tuple with a single one-dimensional array of
                            type ``dtype_``

        An input which still has more than one dimension after being
        squeezed is not handled.
        """
        if len(x.shape) != 1:
            # squeeze may return a 0-d array (input of shape (1, 1) or a
            # scalar), atleast_1d turns it back into a vector of one element.
            x = numpy.atleast_1d(numpy.squeeze(x))
        res = numpy.empty((x.shape[0], ), dtype=self.dtype_)
        lookup = self.classes_.get
        default = self.default_
        for i in range(0, res.shape[0]):
            res[i] = lookup(x[i], default)
        return (res, )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns the expected output shape as a tuple of
        one ``ShapeObject``.
        """
        # NOTE(review): the second dimension is the number of classes
        # whereas _run returns a one-dimensional array -- confirm that
        # this is the intended shape.
        nb = len(self.classes_.values())
        return (ShapeObject((x[0], nb), dtype=self.dtype_,
                            name="{}-1".format(self.__class__.__name__)), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the output type, it does not depend on *x*.
        """
        return (self.dtype_, )