Coverage for mlprodict/onnxrt/ops_cpu/op_one_hot_encoder.py: 76%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ..shape_object import DimensionObject
class OneHotEncoder(OpRun):
    """
    Runtime for operator *OneHotEncoder*.

    Every element of the input is replaced by a vector of
    ``len(classes_)`` values holding a single one at the index of the
    element's category, the output shape is the input shape plus
    one trailing dimension. Elements which do not belong
    to any known category produce a vector of zeros or make the
    operator fail depending on attribute *zeros*.

    :epkg:`ONNX` specifications does not mention
    the possibility to change the output type,
    sparse, dense, float, double.
    This runtime always produces a dense *float32* array.
    """

    # Expected attributes and their default values. Only one of
    # 'cats_int64s', 'cats_strings' is used (integers take precedence).
    # 'zeros': if false, an unknown category raises an exception.
    atts = {'cats_int64s': numpy.empty(0, dtype=numpy.int64),
            'cats_strings': numpy.empty(0, dtype=numpy.str_),
            'zeros': 1,
            }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds ``self.classes_``, a dictionary
        ``{category: column index}``.

        @param      onnx_node   forwarded to ``OpRun.__init__``
        @param      desc        forwarded to ``OpRun.__init__``
        @param      options     forwarded to ``OpRun.__init__``

        Raises *RuntimeError* if neither *cats_int64s*
        nor *cats_strings* contains a category.
        """
        # The code below relies on OpRun.__init__ to expose the
        # attributes listed in *atts* as members of the instance.
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=OneHotEncoder.atts,
                       **options)
        if len(self.cats_int64s) > 0:
            self.classes_ = {v: i for i, v in enumerate(self.cats_int64s)}
        elif len(self.cats_strings) > 0:
            # Categories stored as bytes are decoded, categories
            # already stored as str are kept as they are
            # (str has no method decode).
            self.classes_ = {
                (v.decode('utf-8') if isinstance(v, bytes) else v): i
                for i, v in enumerate(self.cats_strings)}
        else:
            raise RuntimeError("No encoding was defined.")  # pragma: no cover

    def _run(self, x):  # pylint: disable=W0221
        """
        One-hot encodes array *x*.

        @param      x       numpy array of any rank, its elements are
                            looked up in ``self.classes_``
        @return             tuple with one *float32* array
                            of shape ``x.shape + (len(self.classes_), )``

        Raises *RuntimeError* if attribute *zeros* is false
        and at least one element of *x* is not a known category.
        """
        res = numpy.zeros(x.shape + (len(self.classes_), ),
                          dtype=numpy.float32)
        lookup = self.classes_.get
        # ndenumerate walks through every element whatever the rank is,
        # *index* is the tuple of coordinates of value *v* in *x*.
        for index, v in numpy.ndenumerate(x):
            j = lookup(v, -1)
            if j >= 0:
                res[index + (j, )] = 1.

        if not self.zeros:
            # An element without any category sums to zero
            # over the last axis.
            red = res.sum(axis=-1)
            missing = numpy.argwhere(red == 0)
            if len(missing) > 0:
                # The message only shows the first six failing elements.
                rows = []
                for idx in missing[:6]:
                    coords = tuple(int(k) for k in idx)
                    rows.append(dict(
                        row=coords[0] if len(coords) == 1 else coords,
                        value=x[coords]))
                # A 0-d array cannot be sliced.
                head = x[:5] if x.ndim > 0 else x
                raise RuntimeError(  # pragma: no cover
                    "One observation did not have any defined category.\n"
                    "classes: {}\nfirst rows:\n{}\nres:\n{}\nx:\n{}".format(
                        self.classes_, "\n".join(str(_) for _ in rows),
                        res[:5], head))

        return (res, )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns the output shape: a copy of input shape *x*
        with one more dimension equal to the number of categories,
        typed *float32* and named after the node.

        @param      x       input shape, it must implement *copy*
                            and *append* (presumably a *ShapeObject*,
                            to be confirmed against the callers)
        @return             tuple with one shape
        """
        new_shape = x.copy()
        dim = DimensionObject(len(self.classes_))
        new_shape.append(dim)
        new_shape._dtype = numpy.float32
        new_shape.name = self.onnx_node.name
        return (new_shape, )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the output type, always *float32*
        whatever the input type *x* is.
        """
        return (numpy.float32, )