Coverage for mlprodict/onnxrt/ops_cpu/op_cast.py: 91%

Shortcuts on this page

r, m, x: toggle line displays

j, k: next/previous highlighted chunk

0 (zero): top of page

1 (one): first highlighted chunk

46 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.onnx_pb import TensorProto 

9from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE 

10from ._op import OpRun 

11 

12 

class Cast(OpRun):
    """
    Runtime for the ONNX *Cast* operator: converts the single input
    tensor to the element type selected by the ``to`` attribute.
    """

    atts = {'to': None}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Cast.atts,
                       **options)
        # Resolve the numpy target type once, at construction time.
        if self.to == TensorProto.STRING:  # pylint: disable=E1101
            # Unsized string type: numpy chooses the output width itself
            # instead of truncating to a fixed one.
            self._dtype = numpy.str_
        else:
            self._dtype = TENSOR_TYPE_TO_NP_TYPE[self.to]

    def _cast(self, x):
        """
        Returns *x* converted to the target type ``self._dtype``.

        Implemented as a regular method rather than a lambda stored on
        the instance so that operator instances remain picklable.
        """
        return x.astype(self._dtype)

    def _run(self, x):  # pylint: disable=W0221
        """
        Computes the cast.

        :param x: input array
        :return: 1-tuple holding the converted array
        """
        # ``inplaces`` is presumably filled by OpRun (input index -> bool);
        # when input 0 may be reused, avoid a copy if possible.
        if self.inplaces.get(0, False):
            return self._run_inplace(x)
        return (self._cast(x), )

    def _run_inplace(self, x):
        """
        Same as ``_run`` but returns *x* itself (no copy) when it already
        has the target type.
        """
        if x.dtype == self._dtype:
            return (x, )
        return (self._cast(x), )

    def _infer_shapes(self, x):  # pylint: disable=W0221
        """
        Returns the output shape description: a copy of *x* with its
        element type replaced by the target type (assumes *x* exposes
        ``copy(dtype=...)`` -- TODO confirm against the shape class).
        """
        return (x.copy(dtype=self._dtype), )

    def _infer_types(self, x):  # pylint: disable=W0221
        """
        Returns the output element type, which depends only on the
        ``to`` attribute, not on the input type.
        """
        return (self._dtype, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator on the given inputs and prepends
        ``{'temp': 0}`` to its outputs (presumably meaning no temporary
        buffer is required -- verify against OpRun).
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res

46 

47 

class CastLike(OpRun):
    """
    Runtime for the ONNX *CastLike* operator: converts the first input
    to the element type of the second input.
    """

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc, **options)

    @staticmethod
    def _target_dtype(dtype):
        """
        Returns the numpy type to convert to, given the element type
        *dtype* of the reference tensor.

        Sized string types (``<U3``, ``|S2``...) are replaced by their
        unsized class, otherwise numpy truncates every converted value to
        the width of the reference tensor (``1.5`` cast to ``<U1`` gives
        ``'1'``). This matches ``Cast``, which uses ``numpy.str_``.
        """
        dtype = numpy.dtype(dtype)
        if dtype.kind == 'U':
            return numpy.str_
        if dtype.kind == 'S':
            return numpy.bytes_
        return dtype

    def _run(self, x, y):  # pylint: disable=W0221
        """
        Computes the cast.

        :param x: array to convert
        :param y: array whose element type is the target type
        :return: 1-tuple holding the converted array
        """
        # ``inplaces`` is presumably filled by OpRun (input index -> bool).
        if self.inplaces.get(0, False):
            return self._run_inplace(x, y)
        return (x.astype(self._target_dtype(y.dtype)), )

    def _run_inplace(self, x, y):
        """
        Same as ``_run`` but returns *x* itself (no copy) when it already
        has the element type of *y*.
        """
        if x.dtype == y.dtype:
            return (x, )
        return (x.astype(self._target_dtype(y.dtype)), )

    def _infer_shapes(self, x, y):  # pylint: disable=W0221
        """
        Returns a copy of the shape description of *x* carrying the
        element type of *y* (assumes both expose ``dtype`` and *x*
        exposes ``copy(dtype=...)`` -- TODO confirm).
        """
        return (x.copy(dtype=y.dtype), )

    def _infer_types(self, x, y):  # pylint: disable=W0221
        """
        Returns the output element type, i.e. the type of *y*.

        Type inference receives plain element types (``Cast._infer_types``
        also returns a bare numpy type), so *y* is returned as is; the
        former ``y._dtype`` raised ``AttributeError`` on such inputs.
        """
        return (y, )

    def _infer_sizes(self, *args, **kwargs):
        """
        Runs the operator on the given inputs and prepends
        ``{'temp': 0}`` to its outputs (presumably meaning no temporary
        buffer is required -- verify against OpRun).
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res