Coverage for mlprodict/onnxrt/ops_cpu/op_scatter_elements.py: 98%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

45 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ..shape_object import ShapeObject 

9from ._op import OpRun 

10 

11 

def scatter_elements(data, indices, updates, axis=0):
    """
    Reference implementation of the ONNX operator *ScatterElements*
    (without reduction). It returns a copy of *data* in which, for every
    position ``p`` of *indices*, the element located at ``p`` with its
    coordinate along *axis* replaced by ``indices[p]`` receives
    ``updates[p]``.

    ::

        // for 3-dim and axis=0
        // output[indices[i][j][k]][j][k] = updates[i][j][k]
        // for axis 1
        // output[i][indices[i][j][k]][k] = updates[i][j][k]
        // and so on

    :param data: input tensor of rank r >= 1, it is not modified
    :param indices: integer tensor with the same rank as *data*,
        negative values count from the end of dimension *axis*
    :param updates: tensor with the same shape as *indices*
        (it is read at the coordinates of *indices*)
    :param axis: axis to scatter along, a negative value counts
        from the last dimension
    :return: scattered copy of *data*
    :raises ValueError: if *axis* is out of range or if *indices*
        and *data* do not have the same rank
    """
    data = numpy.asarray(data)
    indices = numpy.asarray(indices)
    updates = numpy.asarray(updates)

    # Normalize the axis before anything else so that axis=-1 also
    # works for 1-D tensors.
    if axis < 0:
        axis += data.ndim
    if not 0 <= axis < data.ndim:
        raise ValueError(
            "axis=%r is out of range for a tensor of rank %d."
            "" % (axis, data.ndim))
    if indices.ndim != data.ndim:
        # With fewer index arrays than dimensions, NumPy would silently
        # assign whole slices instead of single elements.
        raise ValueError(
            "indices (rank %d) and data (rank %d) must have the same "
            "rank." % (indices.ndim, data.ndim))

    # coords[d] holds the d-th coordinate of every position of indices.
    coords = list(numpy.indices(indices.shape))
    # Source elements: updates is read at the positions of indices.
    values = updates[tuple(coords)]
    # Target elements: same coordinates except along axis, where the
    # coordinate is taken from indices (NumPy resolves negative values).
    coords[axis] = indices

    scattered = numpy.copy(data)
    scattered[tuple(coords)] = values
    return scattered

62 

63 

class ScatterElements(OpRun):
    """
    Python runtime for the ONNX operator *ScatterElements*.
    The computation is delegated to :func:`scatter_elements`
    with the operator attribute *axis*.
    """

    # Operator attributes and their default values.
    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       **options)
        # NOTE(review): *atts* is not forwarded to OpRun here. Unless OpRun
        # applies it on its own, ``self.axis`` would be missing when the
        # node omits *axis*. Fill in any attribute still undefined with its
        # default value.
        for name, default in ScatterElements.atts.items():
            if not hasattr(self, name):
                setattr(self, name, default)

    def _run(self, data, indices, updates):  # pylint: disable=W0221
        """
        Scatters *updates* into a copy of *data* along ``self.axis``.

        :return: tuple with the scattered tensor
        """
        res = scatter_elements(data, indices, updates, axis=self.axis)
        return (res, )

    def _infer_shapes(self, data, indices, updates):  # pylint: disable=W0221
        """
        The output has the same shape and dtype as *data*.
        """
        return (ShapeObject(data.shape, dtype=data.dtype), )

    def _infer_types(self, data, indices, updates):  # pylint: disable=W0221
        """
        The output has the same type as *data*.
        """
        return (data, )

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        """
        Runs the operator and returns ``dict(temp=0)`` followed
        by the outputs of :meth:`run`.
        """
        res = self.run(*args)
        return (dict(temp=0), ) + res