Coverage for mlprodict/onnxrt/ops_cpu/op_label_encoder.py: 93%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ..shape_object import ShapeObject
9from ._op import OpRun
class LabelEncoder(OpRun):
    """
    Python runtime for the ONNX-ML operator *LabelEncoder*
    (version 2 attributes: ``keys_*``, ``values_*``, ``default_*``).

    Every element of the input is looked up in a dictionary built from one
    pair of parallel attributes ``keys_<type>`` / ``values_<type>``; elements
    missing from the dictionary are replaced by the ``default_<type>``
    attribute that matches the type of the values.
    """

    # NOTE(review): the ONNX specification documents -0.0 and '_Unused' as
    # the defaults of default_float and default_string; this runtime uses
    # 0. and b'' -- confirm before relying on unset defaults.
    atts = {'default_float': 0., 'default_int64': -1,
            'default_string': b'',
            'keys_floats': numpy.empty(0, dtype=numpy.float32),
            'keys_int64s': numpy.empty(0, dtype=numpy.int64),
            'keys_strings': numpy.empty(0, dtype=numpy.str_),
            'values_floats': numpy.empty(0, dtype=numpy.float32),
            'values_int64s': numpy.empty(0, dtype=numpy.int64),
            'values_strings': numpy.empty(0, dtype=numpy.str_),
            }

    # Supported (keys attribute, values attribute) pairs, tried in this
    # order: the first pair whose two attributes are both non-empty defines
    # the encoding.
    _encodings = (
        ('keys_floats', 'values_floats'),
        ('keys_floats', 'values_int64s'),
        ('keys_int64s', 'values_int64s'),
        ('keys_int64s', 'values_floats'),
        ('keys_strings', 'values_floats'),
        ('keys_strings', 'values_int64s'),
        ('keys_strings', 'values_strings'),
        ('keys_floats', 'values_strings'),
        ('keys_int64s', 'values_strings'),
    )

    # Output dtype and default-value attribute for each kind of values.
    # String outputs are stored in object arrays so that no string,
    # including the default one, is truncated to a fixed width.
    _value_kinds = {
        'values_floats': (numpy.float32, 'default_float'),
        'values_int64s': (numpy.int64, 'default_int64'),
        'values_strings': (numpy.dtype(object), 'default_string'),
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Loads the attributes of *onnx_node* and builds the lookup table.

        :param onnx_node: ONNX node holding the LabelEncoder attributes
        :param desc: optional description forwarded to :class:`OpRun`
        :param options: additional options forwarded to :class:`OpRun`
        :raises RuntimeError: if no supported (keys, values) pair is
            defined, if the node uses the version 1 attribute
            ``classes_strings``, or if the resulting table is empty
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=LabelEncoder.atts,
                       **options)
        for keys_name, values_name in LabelEncoder._encodings:
            keys = getattr(self, keys_name)
            values = getattr(self, values_name)
            if len(keys) > 0 and len(values) > 0:
                break
        else:
            if hasattr(self, 'classes_strings'):
                raise RuntimeError(  # pragma: no cover
                    "This runtime does not implement version 1 of "
                    "operator LabelEncoder.")
            raise RuntimeError(
                "No encoding was defined in {}.".format(onnx_node))

        # The default is chosen from the *values* type so that it always
        # matches the output dtype.
        dtype, default_name = LabelEncoder._value_kinds[values_name]
        default = getattr(self, default_name)
        if keys_name == 'keys_strings':
            keys = [LabelEncoder._decode(k) for k in keys]
        if values_name == 'values_strings':
            values = [LabelEncoder._decode(v) for v in values]
            # Decoded like the mapped values so that unknown inputs do not
            # produce bytes among str results.
            default = LabelEncoder._decode(default)

        self.classes_ = dict(zip(keys, values))
        self.default_ = default
        self.dtype_ = dtype
        if len(self.classes_) == 0:
            raise RuntimeError(  # pragma: no cover
                "Empty classes for LabelEncoder, (onnx_node='{}')\n{}.".format(
                    self.onnx_node.name, onnx_node))

    @staticmethod
    def _decode(text):
        """
        Returns *text* as :class:`str`: bytes are decoded from UTF-8,
        anything else is returned unchanged.
        """
        if isinstance(text, bytes):
            return text.decode('utf-8')
        return text

    def _run(self, x):  # pylint: disable=W0221
        """
        Maps every element of *x* through ``classes_``; elements not found
        are replaced by ``default_``. Returns a 1-D array of dtype
        ``dtype_`` with one entry per element of the (squeezed) input.
        """
        if len(x.shape) > 1:
            # Drops unit dimensions; atleast_1d keeps a (1, 1) input from
            # collapsing into a 0-d array whose shape[0] would not exist.
            x = numpy.atleast_1d(numpy.squeeze(x))
        lookup = self.classes_.get
        default = self.default_
        res = numpy.empty((x.shape[0], ), dtype=self.dtype_)
        for i, value in enumerate(x):
            res[i] = lookup(value, default)
        return (res, )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Infers the output shape: one value per row of *x*, which matches
        the 1-D array returned by ``_run``.
        """
        return (ShapeObject((x[0], ), dtype=self.dtype_,
                            name="{}-1".format(self.__class__.__name__)), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the output dtype selected at construction time.
        """
        return (self.dtype_, )