Coverage for mlprodict/onnxrt/ops_cpu/op_qlinear_conv.py: 92%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

40 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ..shape_object import ShapeObject 

10from .op_qlinear_conv_ import QLinearConvInt8, QLinearConvUInt8 # pylint: disable=E0611,E0401 

11 

12 

class QLinearConv(OpRun):
    """
    Runtime for ONNX operator *QLinearConv* (quantized convolution).

    The computation is delegated to two compiled kernels,
    :class:`QLinearConvUInt8` and :class:`QLinearConvInt8`, both configured
    with the same node attributes. The kernel used at run time is selected
    from the dtype of input *X*: ``uint8`` uses the unsigned kernel, any
    other dtype uses the signed one.
    """

    atts = {'auto_pad': 'NOTSET',
            'group': 1,
            'dilations': [],
            'kernel_shape': [],
            'pads': [],
            'strides': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=QLinearConv.atts,
                       **options)
        self._init()
        # Empty placeholders handed to the kernels when the optional bias
        # input B is missing, one per kernel dtype.
        self._cstu8 = numpy.array([], dtype=numpy.uint8)
        self._csti8 = numpy.array([], dtype=numpy.int8)

    def _init(self):
        """
        Creates both compiled kernels and initializes them with the node
        attributes, so that *_run* only has to pick one of them.
        """
        self.rtu8_ = QLinearConvUInt8()
        self.rti8_ = QLinearConvInt8()
        for rt in [self.rtu8_, self.rti8_]:
            rt.init(self.auto_pad,
                    numpy.array(self.dilations, dtype=numpy.int64),
                    self.group,
                    numpy.array(self.kernel_shape, dtype=numpy.int64),
                    numpy.array(self.pads, dtype=numpy.int64),
                    numpy.array(self.strides, dtype=numpy.int64))

    def _run(self, X, x_scale, x_zero_point, w, w_scale, w_zero_point,  # pylint: disable=W0221
             y_scale, y_zero_point, B=None):
        """
        Computes the quantized convolution.

        :param X: quantized input; dtype ``uint8`` selects the unsigned
            kernel, any other dtype selects the signed kernel
        :param x_scale: scale of *X*
        :param x_zero_point: zero point of *X*
        :param w: quantized weights
        :param w_scale: scale of *w*
        :param w_zero_point: zero point of *w*
        :param y_scale: scale of the output
        :param y_zero_point: zero point of the output
        :param B: optional bias; when None, an empty placeholder array
            matching the selected kernel is passed instead
        :return: one-element tuple holding the kernel output
        :raises ValueError: if *X* is None
        """
        if X is None:
            raise ValueError(  # pragma: no cover
                "X cannot be None for operator %r, ONNX=%r" % (
                    type(self), self.onnx_node))
        if X.dtype == numpy.uint8:
            runtime, empty_bias = self.rtu8_, self._cstu8
        else:
            runtime, empty_bias = self.rti8_, self._csti8
        # Explicit None test: ``B or default`` would evaluate the truth
        # value of a numpy array, which raises for arrays with more than
        # one element and silently drops a one-element zero bias.
        b = empty_bias if B is None else B
        return (runtime.compute(
            X, x_scale, x_zero_point, w, w_scale, w_zero_point,
            y_scale, y_zero_point, b), )

    def _infer_shapes(self, X, x_scale, x_zero_point, w, w_scale,  # pylint: disable=W0221
                      w_zero_point, y_scale, y_zero_point, B=None):
        """
        Returns an output shape of unknown dimensions with the same dtype
        as *X*.
        """
        return (ShapeObject(None, dtype=X.dtype), )

    def _infer_types(self, X, x_scale, x_zero_point, w, w_scale,  # pylint: disable=W0221
                     w_zero_point, y_scale, y_zero_point, B=None):
        """
        Returns the type of *X* as the output type.
        """
        return (X, )

    def _infer_sizes(self, *args, **kwargs):  # pylint: disable=W0221
        """
        Runs the operator and estimates the temporary buffer size in bytes.

        The estimate is ``C // group * prod(kernel_shape)`` times the number
        of output elements, times the item size of *X*. *X* and *w* are read
        from positional arguments 0 and 3, the same order as in *_run*.

        :return: tuple ``(dict(temp=<bytes>), output)``
        """
        res = self.run(*args, **kwargs)
        X = args[0]
        # assumes X is laid out (N, C, ...) as in the ONNX spec
        C = X.shape[1]
        kernel_shape = self.kernel_shape
        if len(kernel_shape) == 0 and len(args) > 3:
            # kernel_shape is optional in ONNX and is then inferred from
            # the spatial dimensions of w, i.e. w.shape[2:].
            kernel_shape = args[3].shape[2:]
        kernel_size = int(numpy.prod(kernel_shape))
        # integer division: this is an element count, not a ratio
        kernel_dim = C // self.group * kernel_size
        temp = kernel_dim * res[0].size
        return (dict(temp=temp * X.dtype.itemsize), ) + res