Coverage for mlprodict/onnxrt/ops_cpu/op_category_mapper.py: 86%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

35 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class CategoryMapper(OpRun):
    """
    Runtime for the *CategoryMapper* operator.

    Two parallel attribute lists, *cats_int64s* and *cats_strings*, define
    a bijection between int64 values and strings. An int64 input is mapped
    to strings (unknown values become *default_string*); any other input is
    treated as strings and mapped to int64 (unknown values become
    *default_int64*).
    """

    atts = {'cats_int64s': numpy.empty(0, dtype=numpy.int64),
            'cats_strings': numpy.empty(0, dtype=numpy.str_),
            'default_int64': -1,
            'default_string': b'',
            }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=CategoryMapper.atts,
                       **options)
        if len(self.cats_int64s) != len(self.cats_strings):
            raise RuntimeError(  # pragma: no cover
                "Lengths mismatch between cats_int64s (%d) and "
                "cats_strings (%d)." % (
                    len(self.cats_int64s), len(self.cats_strings)))
        # Lookup tables in both directions; string keys/values are always
        # str so that lookups do not depend on how the attribute was stored.
        self.int2str_ = {}
        self.str2int_ = {}
        for a, b in zip(self.cats_int64s, self.cats_strings):
            be = CategoryMapper._to_str(b)
            self.int2str_[a] = be
            self.str2int_[be] = a
        # The class-level default for default_string is bytes (b'') while
        # the mapped values are str: decode it once so that the int64 -> str
        # output never mixes bytes and str.
        self.default_string_ = CategoryMapper._to_str(self.default_string)

    @staticmethod
    def _to_str(value):
        """
        Returns *value* as a str when it is bytes (``numpy.bytes_`` is a
        subclass of bytes), decoding it as UTF-8; any other value is
        returned unchanged.
        """
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    def _run(self, x):  # pylint: disable=W0221
        """
        Maps every element of *x* and returns a 1-tuple holding an array
        with the same shape as *x*: a unicode string array if *x* is int64,
        an int64 array otherwise.
        """
        flat = x.ravel()
        if x.dtype == numpy.int64:
            res = [self.int2str_.get(v, self.default_string_) for v in flat]
            # An explicit dtype keeps the result a unicode array, even for
            # an empty input (numpy.array([]) would be float64).
            return (numpy.array(res, dtype=numpy.str_).reshape(x.shape), )

        to_str = CategoryMapper._to_str
        # Bytes elements are decoded so they can match the str keys.
        res = numpy.array(
            [self.str2int_.get(to_str(v), self.default_int64) for v in flat],
            dtype=numpy.int64)
        return (res.reshape(x.shape), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns a copy of the input shape description *x* with the output
        dtype: str if *x* is int64, int64 otherwise.
        """
        if x.dtype == numpy.int64:
            return (x.copy(dtype=numpy.str_), )
        return (x.copy(dtype=numpy.int64), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the output type: ``numpy.str_`` if *x* is int64,
        ``numpy.int64`` otherwise.
        """
        if x.dtype == numpy.int64:
            return (numpy.str_, )
        return (numpy.int64, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator and prepends a dictionary reporting no temporary
        buffer (``temp=0``) to its outputs.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res