Coverage for mlprodict/onnxrt/ops_cpu/op_slice.py: 98%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from onnx.defs import onnx_opset_version
8from ..shape_object import ShapeObject
9from ._op import OpRun
class SliceCommon(OpRun):
    """
    Shared implementation of the ONNX *Slice* operator.

    Subclasses only differ in where *starts*, *ends*, *axes* and
    *steps* come from; the slicing itself is done here by building
    Python ``slice`` objects and indexing *data* with them.
    """

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       **options)

    def _run(self, data, starts, ends, axes=None, steps=None):  # pylint: disable=W0221
        """
        Extracts a slice of *data*.

        :param data: array to slice (anything exposing ``shape`` and
            accepting a tuple of slices as index)
        :param starts: start index for every sliced axis
        :param ends: end index (excluded) for every sliced axis
        :param axes: axes the previous sequences refer to; None means
            the i-th values apply to axis i
        :param steps: step for every sliced axis; None means no step
            (equivalent to 1)
        :return: tuple with a single element, the sliced data
        :raises TypeError: if the slices cannot be applied to *data*
        """
        if steps is None:
            # slice(s, e, None) is identical to slice(s, e), so a single
            # code path handles both the stepped and unstepped cases.
            steps = [None] * len(starts)
        if axes is None:
            # Implicit axes: values apply to the leading axes in order,
            # the shorter index tuple leaves the remaining axes whole.
            slices = [slice(s, e, d) for s, e, d in zip(starts, ends, steps)]
        else:
            # Full slice on every axis, then override the requested ones.
            # A negative axis indexes this list from the end.
            slices = [slice(0, dim) for dim in data.shape]
            for s, e, a, d in zip(starts, ends, axes, steps):
                slices[a] = slice(s, e, d)
        try:
            return (data[tuple(slices)], )
        except TypeError as e:  # pragma: no cover
            raise TypeError(
                "Unable to extract slice %r for shape %r." % (slices, data.shape)) from e

    def _infer_shapes(self, data, starts, ends, axes=None, steps=None):  # pylint: disable=W0221
        """
        Infers the output shape.

        The rank of *data* is kept and every dimension receives a fresh
        symbolic name unique to this node instance; *starts*, *ends*,
        *axes* and *steps* are not used. An unknown input shape gives
        an unknown output shape.
        """
        if data.shape is None:
            return (ShapeObject(None, data.dtype), )
        # hex() already returns a str; drop the '0x' prefix.
        pref = hex(id(self))[2:]
        shape = ["nslice%s_%d" % (pref, i) for i in range(len(data.shape))]
        return (ShapeObject(shape, data.dtype), )

    def _infer_types(self, data, starts, ends, axes=None, steps=None):  # pylint: disable=W0221
        """
        The output type is the type of *data*.
        """
        return (data, )

    def _infer_sizes(self, *args, **kwargs):  # pylint: disable=W0221
        """
        Runs the operator and reports no temporary buffer
        (``temp=0``) followed by the outputs of ``run``.
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res
class Slice_10(SliceCommon):
    """
    *Slice* operator selected when the installed onnx supports
    opset 10 or later. It reuses ``SliceCommon._run`` unchanged, so
    *starts*, *ends*, *axes* and *steps* are presumably provided as
    node inputs by the runtime (verify against ``OpRun.run``).
    """

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc, **options)
class Slice_1(SliceCommon):
    """
    *Slice* operator used when the installed onnx supports an opset
    older than 10: *starts*, *ends* and *axes* are node attributes
    and are forwarded to ``SliceCommon``.
    """

    # Expected attributes passed to OpRun; empty lists presumably act as
    # defaults meaning "not specified".
    atts = {'starts': [], 'ends': [], 'axes': []}

    def __init__(self, onnx_node, desc=None, **options):
        SliceCommon.__init__(self, onnx_node, desc=desc,
                             expected_attributes=Slice_1.atts,
                             **options)
        # An empty attribute is normalised to None so SliceCommon treats
        # it as missing (e.g. implicit axes).
        for attr_name in ('starts', 'ends', 'steps', 'axes'):
            if hasattr(self, attr_name):
                current = getattr(self, attr_name)
                if current is not None and len(current) == 0:
                    setattr(self, attr_name, None)

    def _run(self, data):  # pylint: disable=W0221
        starts, ends, axes = self.starts, self.ends, self.axes
        return SliceCommon._run(self, data, starts, ends, axes)

    def _infer_shapes(self, data):  # pylint: disable=W0221
        starts, ends, axes = self.starts, self.ends, self.axes
        return SliceCommon._infer_shapes(self, data, starts, ends, axes)

    def _infer_types(self, data):  # pylint: disable=W0221
        # Output type is the input type.
        return (data, )
# Alias the implementation matching the opset supported by the installed onnx.
Slice = Slice_10 if onnx_opset_version() >= 10 else Slice_1