Coverage for mlprodict/onnxrt/ops_cpu/op_slice.py: 98%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

52 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from onnx.defs import onnx_opset_version 

8from ..shape_object import ShapeObject 

9from ._op import OpRun 

10 

11 

class SliceCommon(OpRun):
    """
    Shared runtime implementation of the *Slice* operator.

    Every ``(start, end[, step])`` triple is converted into a Python
    :class:`slice` and the resulting tuple is applied to *data* in a single
    indexing operation (assumes *data* supports numpy-style indexing).
    """

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc, **options)

    def _run(self, data, starts, ends, axes=None, steps=None):  # pylint: disable=W0221
        """
        Slices *data*.

        :param data: tensor to slice
        :param starts: first index of every sliced dimension
        :param ends: end index (excluded) of every sliced dimension
        :param axes: dimensions the triples apply to, when None the i-th
            triple applies to dimension i
        :param steps: step of every sliced dimension, None means the
            Python default step (1)
        :return: tuple holding the sliced tensor
        :raises TypeError: when the built slices cannot be applied to *data*
        """
        # Pair every (start, end) with its step; a missing step becomes
        # None, which a Python slice treats as a step of 1.
        if steps is None:
            triples = ((b, f, None) for b, f in zip(starts, ends))
        else:
            triples = zip(starts, ends, steps)

        if axes is None:
            # The i-th triple slices the i-th dimension.
            slices = [slice(b, f, st) for b, f, st in triples]
        else:
            # Start from "keep the whole dimension" everywhere and only
            # overwrite the listed axes (a negative axis indexes the list
            # from its end).
            slices = [slice(0, dim) for dim in data.shape]
            for (b, f, st), axis in zip(triples, axes):
                slices[axis] = slice(b, f, st)

        try:
            return (data[tuple(slices)], )
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Unable to extract slice %r for shape %r." % (slices, data.shape)) from e

    def _infer_shapes(self, data, starts, ends, axes=None, steps=None):  # pylint: disable=W0221
        """
        Infers the output shape.

        The rank of *data* is kept but every dimension receives a new
        symbolic name; the sliced sizes are not computed here.
        """
        if data.shape is None:
            return (ShapeObject(None, data.dtype), )
        # Prefix derived from the instance id, presumably to keep the
        # symbolic names distinct between operator instances.
        prefix = hex(id(self))[2:]
        names = ["nslice%s_%d" % (prefix, i) for i in range(len(data.shape))]
        return (ShapeObject(names, data.dtype), )

    def _infer_types(self, data, starts, ends, axes=None, steps=None):  # pylint: disable=W0221
        """
        The output type is the input type.
        """
        return (data, )

    def _infer_sizes(self, *args, **kwargs):  # pylint: disable=W0221
        """
        Runs the operator and prepends a size dictionary with ``temp=0``
        (presumably: no temporary buffer is needed).
        """
        outputs = self.run(*args, **kwargs)
        return (dict(temp=0), ) + outputs

53 

54 

class Slice_10(SliceCommon):
    """
    Input-based variant of *Slice*: *starts*, *ends*, *axes* and *steps*
    are received as operator inputs, all the logic lives in
    :class:`SliceCommon`.
    """

    def __init__(self, onnx_node, desc=None, **options):
        SliceCommon.__init__(self, onnx_node, desc=desc, **options)

59 

60 

class Slice_1(SliceCommon):
    """
    Attribute-based variant of *Slice*: *starts*, *ends* and *axes* are read
    from ``self`` (presumably populated from the node attributes listed in
    ``atts``) instead of being inputs. No step is forwarded.
    """

    atts = {'starts': [], 'ends': [], 'axes': []}

    def __init__(self, onnx_node, desc=None, **options):
        SliceCommon.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Slice_1.atts, **options)
        # An empty list means "not specified": turn it into None so that
        # SliceCommon._run uses its defaults for that argument.
        for name in ('starts', 'ends', 'steps', 'axes'):
            value = getattr(self, name, None)
            if value is not None and len(value) == 0:
                setattr(self, name, None)

    def _run(self, data):  # pylint: disable=W0221
        """
        Slices *data* with the stored attributes.
        """
        return SliceCommon._run(
            self, data, self.starts, self.ends, axes=self.axes)

    def _infer_shapes(self, data):  # pylint: disable=W0221
        """
        Infers the output shape from the stored attributes.
        """
        return SliceCommon._infer_shapes(
            self, data, self.starts, self.ends, axes=self.axes)

    def _infer_types(self, data):  # pylint: disable=W0221
        """
        The output type is the input type.
        """
        return (data, )

85 

86 

# Default implementation exposed as ``Slice``: below opset 10 the
# attribute-based Slice_1 is used, otherwise the input-based Slice_10.
if onnx_opset_version() < 10:
    Slice = Slice_1  # pragma: no cover
else:
    Slice = Slice_10