Coverage for mlprodict/onnxrt/ops_cpu/op_scatter_elements.py: 98%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ..shape_object import ShapeObject
9from ._op import OpRun
def scatter_elements(data, indices, updates, axis=0):
    """
    Reference implementation of ONNX operator *ScatterElements*
    (without reduction), valid for tensors of any rank.

    ::

        // for 3-dim and axis=0
        // output[indices[i][j][k]][j][k] = updates[i][j][k]
        // for axis 1
        // output[i][indices[i][j][k]][k] = updates[i][j][k]
        // and so on

    :param data: input tensor, it is copied and left unchanged
    :param indices: integer tensor, expected to have the same rank
        as *data*; negative values count from the end of dimension
        *axis* (standard numpy indexing)
    :param updates: values to scatter, ``updates[p]`` is written
        for every position *p* of *indices*
    :param axis: axis to scatter along, negative values count
        from the last dimension
    :return: a copy of *data* with *updates* scattered into it
    """
    if axis < 0:
        # Normalise first so that axis=-1 on a 1-D tensor also takes
        # the dedicated 1-D path below (it used to crash otherwise).
        axis = data.ndim + axis

    scattered = numpy.copy(data)

    if data.ndim == 1 and axis == 0:
        # Explicit loop: repeated positions resolve deterministically,
        # the last update wins.
        for pos, up in zip(indices, updates):
            scattered[pos] = up
        return scattered

    # Open (broadcastable) grid enumerating every position of indices:
    # grid[d] holds the coordinate along dimension d.
    grid = numpy.ix_(*[numpy.arange(dim) for dim in indices.shape])
    # Destination coordinates are the same as the source ones except
    # along *axis*, where they come from indices.
    target = list(grid)
    target[axis] = indices
    scattered[tuple(target)] = updates[grid]
    return scattered
class ScatterElements(OpRun):
    """
    Runtime operator *ScatterElements*: delegates the computation to
    function *scatter_elements* with attribute *axis* (default 0).
    """

    # Attributes of the operator and their default values.
    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc, **options)

    def _run(self, data, indices, updates):  # pylint: disable=W0221
        # NOTE(review): self.axis is presumably set by OpRun from
        # `atts` / the node attributes — confirm in OpRun.
        return (scatter_elements(data, indices, updates, axis=self.axis), )

    def _infer_shapes(self, data, indices, updates):  # pylint: disable=W0221
        # The output has the shape and dtype of *data*.
        out_shape = ShapeObject(data.shape, dtype=data.dtype)
        return (out_shape, )

    def _infer_types(self, data, indices, updates):  # pylint: disable=W0221
        # The output type is the type of *data*.
        return (data, )

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        # Runs the operator and reports no temporary buffer.
        outputs = self.run(*args)
        return ({'temp': 0}, ) + outputs