Coverage for mlprodict/onnxrt/ops_cpu/op_quantize_linear.py: 82%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ...onnx_tools.onnx2py_helper import guess_numpy_type_from_dtype
9from ._op import OpRun
10from ..shape_object import ShapeObject
class QuantizeLinear(OpRun):
    """
    Runtime implementation of ONNX operator *QuantizeLinear*:
    ``y = saturate(round(x / y_scale) + y_zero_point)``.

    Rounding is round-half-to-even (:func:`numpy.rint`) and is applied
    before the zero point is added. Saturation clips to the range of the
    output type, which is the dtype of *y_zero_point* (``uint8`` or
    ``int8``), or ``uint8`` when no zero point is given.
    """

    # expected attributes handed to OpRun, with their default (axis=1)
    atts = {'axis': 1}
    python_inputs = ['*inputs']

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=QuantizeLinear.atts,
                       **options)

    def _run(self, *args):  # pylint: disable=W0221
        """
        Quantizes ``args[0]``.

        :param args: ``x``, ``y_scale`` (scalar or 1-D) and optionally
            ``y_zero_point`` (same shape as ``y_scale``)
        :return: 1-tuple holding the quantized tensor, with the dtype of
            ``y_zero_point`` (``uint8`` when it is missing)
        :raises RuntimeError: if ``y_scale`` has more than one dimension
            or ``y_zero_point`` is neither ``uint8`` nor ``int8``
        """
        if len(args[1].shape) > 1:
            raise RuntimeError(  # pragma: no cover
                "Input 2 must be a vector or a number.")
        y_scale = args[1]
        if len(y_scale.shape) > 0 and y_scale.size == 1:
            # a one-element vector is treated as a per-tensor scale
            y_scale = y_scale[0]
        per_axis = len(y_scale.shape) > 0
        if per_axis:
            # per-axis quantization: reshape the 1-D scale so that it
            # broadcasts along dimension ``self.axis`` of the input
            new_shape = [1] * len(args[0].shape)
            new_shape[self.axis] = len(y_scale)
            x = args[0] / y_scale.reshape(new_shape)
        else:
            x = args[0] / y_scale

        # ONNX rounds x / y_scale to the nearest integer (ties to even)
        # *before* adding the zero point. Rounding to integers is
        # required because astype() below would otherwise truncate.
        # Out-of-place operations keep this valid when x is a NumPy
        # scalar (0-d input), where ``out=x`` is not allowed.
        x = numpy.rint(x)

        if len(args) > 2:
            zero_point = args[2]
            dtype = zero_point.dtype
            if per_axis:
                x = x + zero_point.reshape(new_shape)
            else:
                x = x + zero_point
        else:
            dtype = numpy.uint8

        # saturate to the range of the output type
        if dtype == numpy.uint8:
            x = numpy.clip(x, 0, 255)
        elif dtype == numpy.int8:
            x = numpy.clip(x, -128, 127)
        else:
            raise RuntimeError(  # pragma: no cover
                "Unexpected dtype for input 2 {}.".format(dtype))
        return (x.astype(dtype), )

    def _infer_shapes(self, *args):  # pylint: disable=W0221
        """
        Returns a ShapeObject with the shape of ``args[0]`` and the dtype
        of ``args[2]`` (``uint8`` when there is no third input).
        """
        if len(args) > 2:
            dtype = args[2].dtype
        else:
            dtype = numpy.uint8
        return (ShapeObject(args[0].shape, dtype=dtype), )

    def _infer_types(self, *args):  # pylint: disable=W0221
        """
        Returns the output type.

        It is the dtype of ``args[2]`` when that is an array. Otherwise it
        is converted through ``guess_numpy_type_from_dtype``, presumably
        mapping a type description to a numpy type (verify against the
        helper). It is ``uint8`` when there is no third input.
        """
        if len(args) > 2:
            if isinstance(args[2], numpy.ndarray):
                dtype = args[2].dtype
            else:
                dtype = guess_numpy_type_from_dtype(args[2])
        else:
            dtype = numpy.uint8
        return (dtype, )

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        """
        Runs the operator and prepends ``dict(temp=0)`` to its outputs.
        """
        res = self.run(*args)
        return (dict(temp=0), ) + res