Coverage for mlprodict/onnxrt/ops_cpu/op_max_pool.py: 84%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import itertools
8import numpy
9from ..shape_object import ShapeObjectFct
10from ._op import OpRun
11from .op_max_pool_ import MaxPoolFloat, MaxPoolDouble # pylint: disable=E0611,E0401
def _pool_get_output_shape(auto_pad, input_spatial_shape, kernel_spatial_shape,
                           strides_spatial):
    """
    Computes the spatial output shape of a pooling operator
    for the automatic padding modes.

    :param auto_pad: padding mode as bytes, ``b'SAME_UPPER'``,
        ``b'SAME_LOWER'`` or ``b'VALID'``
    :param input_spatial_shape: spatial dimensions of the input
    :param kernel_spatial_shape: kernel size along every spatial
        dimension (only read for ``b'VALID'``)
    :param strides_spatial: stride along every spatial dimension
    :return: list of ints, one per spatial dimension; for any other
        *auto_pad* value (such as ``b'NOTSET'``) every entry is 0
    """
    n_dims = len(input_spatial_shape)

    if auto_pad == b'VALID':
        # Only positions where the whole kernel fits inside the input.
        return [
            int(numpy.ceil(
                float(input_spatial_shape[d] - (kernel_spatial_shape[d] - 1)) /
                float(strides_spatial[d])))
            for d in range(n_dims)]

    if auto_pad in (b'SAME_UPPER', b'SAME_LOWER'):
        # Output size only depends on the input size and the stride.
        return [
            int(numpy.ceil(float(input_spatial_shape[d]) /
                           float(strides_spatial[d])))
            for d in range(n_dims)]

    # No automatic padding: the caller is expected to compute the shape.
    return [0] * n_dims
def _pool_impl(padded, x_shape, kernel_shape, strides_shape,
               out_shape, pad_shape, pooling_type,
               count_include_pad=0):
    """
    Pure python implementation of a pooling operator (*AVG* or *MAX*),
    computed one output element at a time.

    :param padded: padded input, indexed as ``padded[n, c, d1, d2, ...]``
    :param x_shape: shape of the input before padding,
        ``(batch, channels, spatial dimensions...)``
    :param kernel_shape: kernel size along every spatial dimension
    :param strides_shape: stride along every spatial dimension
    :param out_shape: spatial shape of the allocated output
    :param pad_shape: padding along every spatial dimension; the padded
        length of axis ``i`` is taken as ``x_shape[i + 2] + pad_shape[i]``
    :param pooling_type: ``b'AVG'`` or ``b'MAX'``
    :param count_include_pad: only used with ``b'AVG'``; when equal to 1
        every value of the window is averaged (NaN included), otherwise
        NaN values are dropped before the reduction
    :return: float32 array of shape ``[x_shape[0], x_shape[1]] + list(out_shape)``
    :raises NotImplementedError: if *pooling_type* is not supported,
        even when the output is empty
    """
    # The reduction does not depend on the window: select (and validate)
    # it once, before any work is done.
    if pooling_type == b'AVG':
        f = numpy.average
    elif pooling_type == b'MAX':
        f = numpy.max
    else:
        raise NotImplementedError(  # pragma: no cover
            "Pooling type '{}' does not support. Should be AVG, MAX."
            "".format(pooling_type))
    keep_nan = count_include_pad == 1 and pooling_type == b'AVG'

    spatial_size = len(x_shape) - 2
    y = numpy.zeros([x_shape[0], x_shape[1]] + list(out_shape))

    # Number of kernel positions along each spatial axis.
    positions = [
        range(int((x_shape[i + 2] + pad_shape[i] - kernel_shape[i]) /
                  strides_shape[i] + 1))
        for i in range(spatial_size)]

    for shape in itertools.product(range(x_shape[0]), range(x_shape[1]),
                                   *positions):
        window = padded[shape[0], shape[1]]
        # Coordinates (in the padded input) covered by the kernel on each axis.
        extents = [
            range(strides_shape[i] * shape[i + 2],
                  strides_shape[i] * shape[i + 2] + kernel_shape[i])
            for i in range(spatial_size)]
        window_vals = numpy.array(
            [window[idx] for idx in itertools.product(*extents)])
        if keep_nan:
            y[shape] = f(window_vals)
        else:
            # NaN values (presumably padding set by the caller) are ignored.
            y[shape] = f(window_vals[~numpy.isnan(window_vals)])
    return y.astype(numpy.float32)
class MaxPool(OpRun):
    """
    Runtime for operator *MaxPool*. The computation is delegated to the
    compiled classes *MaxPoolFloat* (float32 inputs) and *MaxPoolDouble*
    (any other input type). When the node declares a second output
    (ONNX *Indices*), its element type is int64 as required by the ONNX
    specification, whatever the input type is.
    """

    atts = {'auto_pad': b'NOTSET', 'ceil_mode': 0, 'dilations': [],
            'kernel_shape': [], 'pads': [], 'storage_order': 0,
            'strides': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=MaxPool.atts,
                       **options)
        self.auto_pad_ = self.auto_pad.decode('ascii')
        # Number of outputs declared by the node: 1 (values only)
        # or 2 (values and indices).
        self.nb_outputs = len(onnx_node.output)
        self._init()

    def _init(self):
        """Creates one compiled kernel per floating type and configures
        both with the node attributes."""
        self.rt32_ = MaxPoolFloat()
        self.rt64_ = MaxPoolDouble()
        for rt in [self.rt32_, self.rt64_]:
            rt.init(self.auto_pad,
                    numpy.array(self.dilations, dtype=numpy.int64),
                    self.ceil_mode,
                    self.storage_order,
                    numpy.array(self.kernel_shape, dtype=numpy.int64),
                    numpy.array(self.pads, dtype=numpy.int64),
                    numpy.array(self.strides, dtype=numpy.int64))

    def _run(self, X):  # pylint: disable=W0221
        # float32 inputs use the float kernel, every other dtype the double one.
        if X.dtype == numpy.float32:
            res = self.rt32_.compute(X)
        else:
            res = self.rt64_.compute(X)
        # The kernel always returns both outputs; drop the second one
        # when the node only declares a single output.
        if self.nb_outputs == 1:
            return res[:1]
        return res

    def _infer_shapes(self, X):  # pylint: disable=W0221
        # Shapes are obtained by running the float kernel on a dummy input.

        def compute_shape1(xshape):
            xs = numpy.ones(xshape, dtype=numpy.float32)
            res, _ = self.rt32_.compute(xs)
            return res.shape

        def compute_shape2(xshape):
            xs = numpy.ones(xshape, dtype=numpy.float32)
            _, res2 = self.rt32_.compute(xs)
            return res2.shape

        if self.nb_outputs == 1:
            return (ShapeObjectFct(compute_shape1, X, name="MaxPool", dtype=X.dtype), )
        # The Indices output is int64 (ONNX specification), not X.dtype.
        return (ShapeObjectFct(compute_shape1, X, name="MaxPool", dtype=X.dtype),
                ShapeObjectFct(compute_shape2, X, name="MaxPool", dtype=numpy.int64))

    def _infer_types(self, X):  # pylint: disable=W0221
        if self.nb_outputs == 1:
            return (X, )
        # The Indices output is int64 (ONNX specification), not X's type.
        return (X, numpy.int64)

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        res = self.run(*args)
        return (dict(temp=0), ) + res