Coverage for mlprodict/onnxrt/ops_cpu/op_gather_elements.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ..shape_object import ShapeObject
def gather_numpy_2(self, dim, index):
    """
    Gathers values from a 2-D array along axis *dim* using
    integer array indexing. It is used as a fallback by
    @see fn gather_numpy when :epkg:`numpy:choose` refuses its inputs.

    ::

        out[i][j] = input[index[i][j]][j]  # if dim == 0
        out[i][j] = input[i][index[i][j]]  # if dim == 1

    :param self: 2-D array holding the values to gather
    :param dim: axis along which to index, 0 or 1
        (-2 and -1 are accepted as well)
    :param index: 2-D integer array of indices, negative indices
        count from the end of the axis
    :return: array with the shape of *index* and the dtype of *self*
    :raises ValueError: if *dim* does not designate one of the two axes
    """
    if dim < 0:
        # Same convention as numpy: -1 is the last of the two axes.
        dim += 2
    if dim == 1:
        # One row number per row of index, broadcast over its columns.
        rows = numpy.arange(index.shape[0]).reshape((-1, 1))
        res = self[rows, index]
    elif dim == 0:
        # One column number per column of index, broadcast over its rows.
        cols = numpy.arange(index.shape[1]).reshape((1, -1))
        res = self[index, cols]
    else:
        raise ValueError(
            "dim must be 0 or 1 for a 2-D gather, not {}".format(dim))
    # Keeps the contract of the previous implementation:
    # a fresh array with self's dtype and index's shape.
    return numpy.array(res, dtype=self.dtype).reshape(index.shape)
def gather_numpy(self, dim, index):
    """
    Gathers values along an axis specified by dim.
    For a 3-D tensor the output is specified by:

    ::

        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    :param self: array holding the values to gather
    :param dim: The axis along which to index, a negative value
        counts from the last axis
    :param index: A tensor of indices of elements to gather
    :return: tensor of gathered values
    :raises ValueError: if *index* and *self* differ in size on a
        dimension other than *dim*

    See `How to do scatter and gather operations in numpy?
    <https://stackoverflow.com/questions/46065873/
    how-to-do-scatter-and-gather-operations-in-numpy/46204790#46204790>`_
    """
    # The slices below only describe "every axis but this one" when the
    # axis is non negative: with dim == -1, shape[dim + 1:] is shape[0:].
    axis = dim + len(self.shape) if dim < 0 else dim

    idx_xsection_shape = index.shape[:axis] + index.shape[axis + 1:]
    self_xsection_shape = self.shape[:axis] + self.shape[axis + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError(  # pragma: no cover
            "Except for dimension {}, all dimensions of "
            "index and self should be the same size".format(dim))

    # numpy.choose picks among arrays listed along the first axis,
    # so the gathered axis is moved to the front and moved back afterwards.
    data_swaped = numpy.swapaxes(self, 0, axis)
    index_swaped = numpy.swapaxes(index, 0, axis)

    try:
        gathered = numpy.choose(index_swaped, data_swaped, mode='wrap')
    except ValueError:
        # numpy.choose rejected the inputs,
        # the 2-D case has a dedicated fallback.
        if len(index_swaped.shape) == 2 and len(data_swaped.shape) == 2:
            return gather_numpy_2(self, axis, index)
        raise  # pragma: no cover

    return numpy.swapaxes(gathered, 0, axis)
class GatherElements(OpRun):
    # Runtime wrapper which delegates the computation to gather_numpy.
    # The only declared attribute is 'axis' (default 0); the methods read
    # it as self.axis, presumably set by OpRun from *atts* (to confirm).

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=GatherElements.atts, **options)

    def _run(self, data, indices):  # pylint: disable=W0221
        """
        Gathers elements of *data* along ``self.axis`` at *indices*.
        Empty *indices* produce an empty 1-D array of *data*'s dtype.
        """
        if indices.size != 0:
            return (gather_numpy(data, self.axis, indices), )
        return (numpy.empty((0, ), dtype=data.dtype), )

    def _infer_shapes(self, data, indices):  # pylint: disable=W0221
        """
        Returns a shape object built with no shape and *data*'s dtype.
        """
        out_shape = ShapeObject(None, data.dtype)
        return (out_shape, )

    def _infer_types(self, data, indices):  # pylint: disable=W0221
        """
        The output type is the one received for *data*.
        """
        return (data, )

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        """
        Runs the operator and prepends a dictionary whose key ``'temp'``
        is the number of bytes occupied by all the inputs.
        """
        outputs = self.run(*args)
        temp_bytes = 0
        for arr in args:
            temp_bytes += arr.size * arr.dtype.itemsize
        return ({'temp': temp_bytes}, ) + outputs

    def to_python(self, inputs):
        """
        Returns the import statement and the python code computing
        the operator from the input names in *inputs*. The generated
        code refers to a variable named ``axis``.
        """
        data_name = inputs[0]
        index_name = inputs[1]
        code = "\n".join([
            'data_swaped = numpy.swapaxes(%s, 0, axis)' % data_name,
            'index_swaped = numpy.swapaxes(%s, 0, axis)' % index_name,
            "gathered = numpy.choose(index_swaped, data_swaped, mode='wrap')",
            'return numpy.swapaxes(gathered, 0, axis)',
        ])
        return "import numpy", code