Coverage for mlprodict/onnxrt/ops_cpu/op_gather_elements.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

41 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ..shape_object import ShapeObject 

10 

11 

def gather_numpy_2(self, dim, index):
    """
    Gathers values of *self* along axis *dim* with
    :func:`numpy.take_along_axis`. For a matrix,
    ``out[i][j] = self[index[i][j]][j]`` if *dim* is 0 and
    ``out[i][j] = self[i][index[i][j]]`` if *dim* is 1.
    Unlike :func:`numpy.choose`, it has no limit on the size of axis *dim*,
    which makes it usable as a fallback for :func:`gather_numpy`.

    :param self: input tensor
    :param dim: axis along which to gather, a negative value counts
        from the last dimension
    :param index: integer tensor with the same rank as *self*; indices are
        wrapped modulo ``self.shape[dim]`` like
        ``numpy.choose(..., mode='wrap')``
    :return: tensor with the shape of *index* and the dtype of *self*
    """
    axis = dim + len(self.shape) if dim < 0 else dim
    # Wrapping keeps results identical to the numpy.choose(mode='wrap')
    # code path and maps negative indices onto valid positions.
    wrapped = numpy.mod(index, self.shape[axis])
    return numpy.take_along_axis(self, wrapped, axis=axis)

19 

20 

def gather_numpy(self, dim, index):
    """
    Gathers values along an axis specified by dim.
    For a 3-D tensor the output is specified by:

    ::

        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    :param self: input tensor
    :param dim: The axis along which to index, a negative value counts
        from the last dimension
    :param index: A tensor of indices of elements to gather, it must have
        the same shape as *self* except along *dim*; indices are wrapped
        modulo ``self.shape[dim]`` (negative values count from the end)
    :return: tensor of gathered values, with the shape of *index*
    :raises ValueError: if *index* and *self* differ on a dimension
        other than *dim*

    See `How to do scatter and gather operations in numpy?
    <https://stackoverflow.com/questions/46065873/
    how-to-do-scatter-and-gather-operations-in-numpy/46204790#46204790>`_
    """
    if dim < 0:
        # Normalize first: the cross-section slicing below
        # (shape[dim + 1:]) is wrong for a negative axis.
        dim += len(self.shape)
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError(  # pragma: no cover
            "Except for dimension {}, all dimensions of "
            "index and self should be the same size".format(dim))
    data_swaped = numpy.swapaxes(self, 0, dim)
    index_swaped = numpy.swapaxes(index, 0, dim)

    try:
        gathered = numpy.choose(index_swaped, data_swaped, mode='wrap')
    except ValueError:
        # numpy.choose only accepts a bounded number of choices
        # (self.shape[dim] here). Fall back to take_along_axis, which
        # works for any rank; wrapping mirrors mode='wrap'.
        wrapped = numpy.mod(index, self.shape[dim])
        return numpy.take_along_axis(self, wrapped, axis=dim)

    return numpy.swapaxes(gathered, 0, dim)

57 

58 

class GatherElements(OpRun):
    """
    Runtime for operator *GatherElements*: for every element of *indices*,
    picks one value of *data* along attribute *axis* (see
    :func:`gather_numpy`).
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=GatherElements.atts,
                       **options)

    def _run(self, data, indices):  # pylint: disable=W0221
        """
        Computes the operator.

        :param data: input tensor
        :param indices: integer tensor of indices, same rank as *data*
        :return: 1-tuple holding the gathered tensor
        """
        if indices.size == 0:
            # ONNX defines the output shape as the shape of *indices*,
            # which holds for empty indices too (e.g. (2, 0), not (0, )).
            return (numpy.empty(indices.shape, dtype=data.dtype), )
        axis = self.axis
        if axis < 0:
            # ONNX accepts axis in [-r, r-1].
            axis += len(data.shape)
        y = gather_numpy(data, axis, indices)
        return (y, )

    def _infer_shapes(self, data, indices):  # pylint: disable=W0221
        """
        Propagates the dtype of *data*; the shape is given as None
        (presumably meaning unknown to ShapeObject — TODO confirm).
        """
        return (ShapeObject(None, data.dtype), )

    def _infer_types(self, data, indices):  # pylint: disable=W0221
        """
        The output has the same type as *data*.
        """
        return (data, )

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        """
        Runs the operator and reports, under key *temp*, the total number
        of bytes held by the inputs, followed by the outputs of *run*.
        """
        res = self.run(*args)
        return (dict(temp=sum(a.size * a.dtype.itemsize for a in args)), ) + res

    def to_python(self, inputs):
        """
        Returns the import statement and a Python snippet based on
        :func:`numpy.choose`.

        NOTE(review): the snippet reads a variable named ``axis`` that is
        expected to be bound by the code generator — confirm. Unlike
        :func:`gather_numpy`, it has no fallback when ``numpy.choose``
        rejects too many choices.
        """
        lines = ['data_swaped = numpy.swapaxes(%s, 0, axis)' % inputs[0],
                 'index_swaped = numpy.swapaxes(%s, 0, axis)' % inputs[1],
                 "gathered = numpy.choose(index_swaped, data_swaped, mode='wrap')",
                 'return numpy.swapaxes(gathered, 0, axis)']
        return "import numpy", "\n".join(lines)