Coverage for mlprodict/onnxrt/ops_cpu/op_max_pool.py: 84%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

67 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator *MaxPool* (and numpy reference helpers for pooling). 

6""" 

7import itertools 

8import numpy 

9from ..shape_object import ShapeObjectFct 

10from ._op import OpRun 

11from .op_max_pool_ import MaxPoolFloat, MaxPoolDouble # pylint: disable=E0611,E0401 

12 

13 

14def _pool_get_output_shape(auto_pad, input_spatial_shape, kernel_spatial_shape, 

15 strides_spatial): 

16 out_shape = [0] * len(input_spatial_shape) 

17 if auto_pad in (b'SAME_UPPER', b'SAME_LOWER'): 

18 for i in range(len(input_spatial_shape)): # pylint: disable=C0200 

19 out_shape[i] = int( 

20 numpy.ceil( 

21 float(input_spatial_shape[i]) / float(strides_spatial[i]))) 

22 elif auto_pad == b'VALID': 

23 for i in range(len(input_spatial_shape)): # pylint: disable=C0200 

24 out_shape[i] = int( 

25 numpy.ceil(float(input_spatial_shape[i] - (kernel_spatial_shape[i] - 1)) / 

26 float(strides_spatial[i]))) 

27 return out_shape 

28 

29 

30def _pool_impl(padded, x_shape, kernel_shape, strides_shape, 

31 out_shape, pad_shape, pooling_type, 

32 count_include_pad=0): 

33 spatial_size = len(x_shape) - 2 

34 y = numpy.zeros([x_shape[0], x_shape[1]] + list(out_shape)) 

35 

36 for shape in itertools.product( 

37 range(x_shape[0]), range(x_shape[1]), 

38 *[range(int((x_shape[i + 2] + pad_shape[i] - kernel_shape[i]) / 

39 strides_shape[i] + 1)) 

40 for i in range(spatial_size)]): 

41 window = padded[shape[0], shape[1]] 

42 window_vals = numpy.array( 

43 [window[i] for i in list( 

44 itertools.product( 

45 *[range(strides_shape[i] * shape[i + 2], 

46 strides_shape[i] * shape[i + 2] + kernel_shape[i]) 

47 for i in range(spatial_size)]))]) 

48 if pooling_type == b'AVG': 

49 f = numpy.average 

50 elif pooling_type == b'MAX': 

51 f = numpy.max 

52 else: 

53 raise NotImplementedError( # pragma: no cover 

54 "Pooling type '{}' does not support. Should be AVG, MAX." 

55 "".format(pooling_type)) 

56 

57 if count_include_pad == 1 and pooling_type == b'AVG': 

58 y[shape] = f(window_vals) 

59 else: 

60 y[shape] = f(window_vals[numpy.where(~numpy.isnan(window_vals))]) 

61 return y.astype(numpy.float32) 

62 

63 

class MaxPool(OpRun):
    """
    Runtime for the ONNX operator *MaxPool*.

    The computation is delegated to two compiled helpers:
    ``MaxPoolFloat`` for float32 inputs and ``MaxPoolDouble`` for any
    other dtype. Their ``compute`` method returns a pair: the pooled tensor
    and the indices of the selected maxima (the ONNX *Indices* output,
    int64 per the ONNX specification). The second item is returned only
    when the node declares two outputs.
    """

    # default value of every attribute the operator accepts
    atts = {'auto_pad': b'NOTSET', 'ceil_mode': 0, 'dilations': [],
            'kernel_shape': [], 'pads': [], 'storage_order': 0,
            'strides': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=MaxPool.atts,
                       **options)
        # auto_pad is stored as bytes, a decoded copy is kept as well
        self.auto_pad_ = self.auto_pad.decode('ascii')
        # 1: pooled tensor only, 2: pooled tensor and indices
        self.nb_outputs = len(onnx_node.output)
        self._init()

    def _init(self):
        """
        Creates the float32 and float64 compiled runtimes and initializes
        both with the node attributes (list attributes are converted into
        int64 arrays).
        """
        self.rt32_ = MaxPoolFloat()
        self.rt64_ = MaxPoolDouble()
        for rt in [self.rt32_, self.rt64_]:
            rt.init(self.auto_pad,
                    numpy.array(self.dilations, dtype=numpy.int64),
                    self.ceil_mode,
                    self.storage_order,
                    numpy.array(self.kernel_shape, dtype=numpy.int64),
                    numpy.array(self.pads, dtype=numpy.int64),
                    numpy.array(self.strides, dtype=numpy.int64))

    def _run(self, X):  # pylint: disable=W0221
        """
        Runs the pooling with the runtime matching the input dtype
        (float32 runtime for float32, float64 runtime otherwise).

        :param X: input tensor
        :return: ``(Y, )`` or ``(Y, Indices)`` depending on the number
            of outputs declared by the node
        """
        rt = self.rt32_ if X.dtype == numpy.float32 else self.rt64_
        res = rt.compute(X)
        if self.nb_outputs == 1:
            return res[:1]
        return res

    def _infer_shapes(self, X):  # pylint: disable=W0221
        """
        Infers the output shapes lazily: the shapes are obtained by running
        the float32 runtime on a tensor of ones with the input shape.
        The second output (indices) is int64.
        """

        def compute_shapes(xshape):
            # one dummy run gives the shapes of both outputs
            xs = numpy.ones(xshape, dtype=numpy.float32)
            res, indices = self.rt32_.compute(xs)
            return res.shape, indices.shape

        def compute_shape1(xshape):
            return compute_shapes(xshape)[0]

        def compute_shape2(xshape):
            return compute_shapes(xshape)[1]

        first = ShapeObjectFct(
            compute_shape1, X, name="MaxPool", dtype=X.dtype)
        if self.nb_outputs == 1:
            return (first, )
        # ONNX constrains the Indices output to int64
        return (first,
                ShapeObjectFct(compute_shape2, X, name="MaxPool",
                               dtype=numpy.int64))

    def _infer_types(self, X):  # pylint: disable=W0221
        """
        The pooled tensor keeps the input type, the indices are int64.
        """
        if self.nb_outputs == 1:
            return (X, )
        return (X, numpy.int64)

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        """
        Runs the operator and reports no temporary buffer (``temp=0``)
        followed by the outputs.
        """
        res = self.run(*args)
        return (dict(temp=0), ) + res