Coverage for mlprodict/onnxrt/ops_cpu/op_category_mapper.py: 86%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
class CategoryMapper(OpRun):
    """
    Runtime for operator *CategoryMapper*.

    Maps int64 values to strings, and strings to int64 values, using the
    two parallel attribute lists *cats_int64s* and *cats_strings*.
    Values missing from the mapping are replaced by *default_string*
    (int64 input) or *default_int64* (string input).
    """

    atts = {'cats_int64s': numpy.empty(0, dtype=numpy.int64),
            'cats_strings': numpy.empty(0, dtype=numpy.str_),
            'default_int64': -1,
            'default_string': b'',
            }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=CategoryMapper.atts,
                       **options)
        if len(self.cats_int64s) != len(self.cats_strings):
            raise RuntimeError(  # pragma: no cover
                "Lengths mismatch between cats_int64s (%d) and "
                "cats_strings (%d)." % (
                    len(self.cats_int64s), len(self.cats_strings)))
        # Both directions of the mapping, built from the parallel lists.
        self.int2str_ = {}
        self.str2int_ = {}
        for a, b in zip(self.cats_int64s, self.cats_strings):
            # Categories may arrive as bytes or as str (the default
            # above is a numpy.str_ array): normalise them all to str.
            be = self._to_str(b)
            self.int2str_[a] = be
            self.str2int_[be] = a

    @staticmethod
    def _to_str(value):
        """
        Returns *value* as a python str, decoding bytes
        (including numpy.bytes_) as UTF-8.
        """
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return str(value)

    def _run(self, x):  # pylint: disable=W0221
        """
        Applies the mapping element-wise.

        :param x: int64 array (mapped to strings) or any other array,
            whose elements are looked up as strings (mapped to int64)
        :return: 1-tuple holding an array with the same shape as *x*
        """
        xf = x.ravel()
        if x.dtype == numpy.int64:
            # The default is decoded so that it has the same type as
            # the mapped values. The explicit dtype keeps the output a
            # str array even for empty inputs or all-default outputs.
            default = self._to_str(self.default_string)
            res = numpy.array(
                [self.int2str_.get(v, default) for v in xf],
                dtype=numpy.str_)
            return (res.reshape(x.shape), )

        res = numpy.array(
            [self.str2int_.get(v, self.default_int64) for v in xf],
            dtype=numpy.int64)
        return (res.reshape(x.shape), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns a copy of *x* with the output dtype: str for an
        int64 input, int64 otherwise.
        NOTE(review): assumes *x* is a shape description object
        exposing ``dtype`` and ``copy(dtype=...)`` - confirm with OpRun.
        """
        if x.dtype == numpy.int64:
            return (x.copy(dtype=numpy.str_), )
        return (x.copy(dtype=numpy.int64), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the output type: numpy.str_ for an int64 input,
        numpy.int64 otherwise.
        """
        if x.dtype == numpy.int64:
            return (numpy.str_, )
        return (numpy.int64, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a dictionary reporting
        no temporary buffer (``temp=0``) to its results.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res