Coverage for mlprodict/onnxrt/ops_cpu/op_qlinear_conv.py: 92%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ..shape_object import ShapeObject
10from .op_qlinear_conv_ import QLinearConvInt8, QLinearConvUInt8 # pylint: disable=E0611,E0401
class QLinearConv(OpRun):
    """
    Runtime for the ONNX operator *QLinearConv*.

    The computation is delegated to the compiled kernels
    *QLinearConvUInt8* (used when *X* is *uint8*) and
    *QLinearConvInt8* (used otherwise).
    """

    # Default values for the attributes the node may not define.
    atts = {'auto_pad': 'NOTSET',
            'group': 1,
            'dilations': [],
            'kernel_shape': [],
            'pads': [],
            'strides': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=QLinearConv.atts,
                       **options)
        self._init()
        # Empty arrays given to the kernels when the optional bias B is missing.
        self._cstu8 = numpy.array([], dtype=numpy.uint8)
        self._csti8 = numpy.array([], dtype=numpy.int8)

    def _init(self):
        """
        Instantiates both compiled kernels (uint8 and int8) and
        initializes them with the node attributes converted to
        int64 arrays.
        """
        self.rtu8_ = QLinearConvUInt8()
        self.rti8_ = QLinearConvInt8()
        for rt in [self.rtu8_, self.rti8_]:
            rt.init(self.auto_pad,
                    numpy.array(self.dilations, dtype=numpy.int64),
                    self.group,
                    numpy.array(self.kernel_shape, dtype=numpy.int64),
                    numpy.array(self.pads, dtype=numpy.int64),
                    numpy.array(self.strides, dtype=numpy.int64))

    def _run(self, X, x_scale, x_zero_point, w, w_scale, w_zero_point,  # pylint: disable=W0221
             y_scale, y_zero_point, B=None):
        """
        Runs the quantized convolution.

        The uint8 kernel is used when *X* is *uint8*, the int8 kernel
        otherwise. A missing bias *B* is replaced by an empty array.

        :return: a one-element tuple holding the kernel output
        :raises ValueError: if *X* is None
        """
        if X is None:
            raise ValueError(  # pragma: no cover
                "X cannot be None for operator %r, ONNX=%r" % (
                    type(self), self.onnx_node))
        if X.dtype == numpy.uint8:
            runtime, empty_bias = self.rtu8_, self._cstu8
        else:
            runtime, empty_bias = self.rti8_, self._csti8
        # B is a numpy array: ``B or default`` would raise for arrays with
        # several elements and discard a bias equal to [0], hence the
        # explicit comparison with None.
        b = empty_bias if B is None else B
        return (runtime.compute(
            X, x_scale, x_zero_point, w, w_scale, w_zero_point,
            y_scale, y_zero_point, b), )

    def _infer_shapes(self, X, x_scale, x_zero_point, w, w_scale,  # pylint: disable=W0221
                      w_zero_point, y_scale, y_zero_point, B=None):
        """
        Returns a shape of unknown dimensions with the dtype of *X*.
        """
        return (ShapeObject(None, dtype=X.dtype), )

    def _infer_types(self, X, x_scale, x_zero_point, w, w_scale,  # pylint: disable=W0221
                     w_zero_point, y_scale, y_zero_point, B=None):
        """
        Returns the type of *X* as the output type.
        """
        return (X, )

    def _infer_sizes(self, *args, **kwargs):  # pylint: disable=W0221
        """
        Runs the operator and estimates the size of a temporary buffer.

        The estimate is ``(C / group) * prod(kernel_shape) * output.size``
        elements of ``X.dtype``. This presumably matches an im2col-style
        buffer in the compiled kernel — TODO confirm.

        When the optional attribute *kernel_shape* is empty, the spatial
        dimensions of the weight (``w.shape[2:]``, the 4th positional input)
        are used instead.

        :return: ``(dict(temp=<bytes>), ) + outputs``
        """
        res = self.run(*args, **kwargs)
        X = args[0]
        C = X.shape[1]
        kernel_shape = self.kernel_shape
        if len(kernel_shape) == 0 and len(args) > 3:
            # kernel_shape is optional in ONNX; infer it from W.
            kernel_shape = args[3].shape[2:]
        kernel_size = int(numpy.prod(kernel_shape))
        kernel_dim = (C // self.group) * kernel_size
        temp = kernel_dim * res[0].size
        return (dict(temp=int(temp * X.dtype.itemsize)), ) + res