Coverage for mlprodict/onnxrt/ops_cpu/op_conv.py: 85%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ..shape_object import ShapeObjectFct
10from .op_conv_ import ConvFloat, ConvDouble # pylint: disable=E0611,E0401
class Conv(OpRun):
    """
    Runtime for operator *Conv*. The computation is delegated to a
    compiled implementation: ``ConvFloat`` when the input is float32,
    ``ConvDouble`` otherwise.
    """

    atts = {'auto_pad': 'NOTSET', 'group': 1,
            'dilations': [1, 1],
            'kernel_shape': [],
            'pads': [],
            'strides': [1, 1]}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Conv.atts,
                       **options)
        self._init()

    def _init(self):
        """
        Creates the float32 and float64 compiled runtimes and configures
        both of them with the node attributes (integer attributes are
        converted into int64 arrays).
        """
        self.rt32_ = ConvFloat()
        self.rt64_ = ConvDouble()
        for rt in [self.rt32_, self.rt64_]:
            rt.init(self.auto_pad,
                    numpy.array(self.dilations, dtype=numpy.int64),
                    self.group,
                    numpy.array(self.kernel_shape, dtype=numpy.int64),
                    numpy.array(self.pads, dtype=numpy.int64),
                    numpy.array(self.strides, dtype=numpy.int64))

    def _run(self, X, W, B=None):  # pylint: disable=W0221
        """
        Computes the convolution.

        :param X: input tensor
        :param W: weight tensor
        :param B: optional bias tensor
        :return: one-element tuple holding the result
        :raises ValueError: if *X* or *W* is None
        :raises RuntimeError: if any given tensor is empty
        """
        if X is None:
            raise ValueError(  # pragma: no cover
                "X cannot be None for operator %r, ONNX=%r" % (
                    type(self), self.onnx_node))
        if W is None:
            raise ValueError(  # pragma: no cover
                "W cannot be None for operator %r, ONNX=%r" % (
                    type(self), self.onnx_node))
        # ``size == 0`` is equivalent to ``min(shape) == 0`` for arrays
        # with at least one dimension, and does not fail on 0-d arrays.
        for name, value in (('X', X), ('W', W), ('B', B)):
            if value is not None and value.size == 0:
                raise RuntimeError(  # pragma: no cover
                    "Unable to run operator Conv on an empty matrix. "
                    "%s.shape=%r." % (name, value.shape))
        if X.dtype == numpy.float32:
            return (self.rt32_.compute(X, W, B), )
        return (self.rt64_.compute(X, W, B), )

    def _infer_shapes(self, X, W, B=None):  # pylint: disable=W0221
        """
        Infers the output shape by running the float32 runtime on arrays
        of ones with the input shapes. The result keeps the dtype of *X*.
        """

        def compute_shape(xshape, wshape, bshape):
            xs = numpy.ones(xshape, dtype=numpy.float32)
            ws = numpy.ones(wshape, dtype=numpy.float32)
            bs = (numpy.ones(bshape, dtype=numpy.float32)
                  if bshape is not None else None)
            res = self.rt32_.compute(xs, ws, bs)
            return res.shape

        return (ShapeObjectFct(
            compute_shape, X, W, B, name="Conv", dtype=X.dtype), )

    def _infer_types(self, X, W, B=None):  # pylint: disable=W0221
        """The output has the same type as *X*."""
        return (X, )

    def _infer_sizes(self, X, W, B=None):  # pylint: disable=W0221
        """
        Runs the operator and estimates the temporary memory it needs.

        The estimate is ``(C // group) * prod(kernel) * output.size`` bytes
        scaled by the item size of *X* (presumably the size of an
        im2col-like buffer; TODO confirm against the compiled runtime).

        :return: ``(dict(temp=...), ) + outputs``
        """
        # The bias must be forwarded, otherwise the returned outputs
        # would be computed without it.
        res = self.run(X, W, B=B)
        C = X.shape[1]
        # ``kernel_shape`` is optional. When it is left at its default
        # (empty), the spatial kernel extent comes from W's trailing
        # dimensions.
        kernel_shape = (self.kernel_shape if len(self.kernel_shape) > 0
                        else W.shape[2:])
        kernel_size = numpy.prod(kernel_shape)
        kernel_dim = C // self.group * kernel_size
        temp = kernel_dim * res[0].size
        return (dict(temp=temp * X.dtype.itemsize), ) + res