Coverage for mlprodict/onnxrt/ops_cpu/op_dequantize_linear.py: 86%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
9from ..shape_object import ShapeObject
class DequantizeLinear(OpRun):
    """
    Python runtime for ONNX operator *DequantizeLinear*.

    Computes ``y = (x - x_zero_point) * x_scale`` in float32 and always
    returns a float32 tensor. Inputs are *x*, *x_scale* and an optional
    *x_zero_point* (which must share the dtype of *x*). A scale or zero
    point holding a single element applies to the whole tensor; a 1-D
    scale or zero point with several elements applies per slice along
    attribute *axis* (default 1, negative values count from the end).
    """

    atts = {'axis': 1}
    python_inputs = ['*inputs']

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=DequantizeLinear.atts,
                       **options)

    def _reshape_parameter(self, param, x):
        """
        Reshapes a scale or a zero point so that it broadcasts against *x*.

        :param param: scale or zero point, a scalar or a 1-D array
        :param x: quantized input tensor
        :return: *param* as a 0-d array when it holds a single element,
            otherwise *param* reshaped to ``[1, ..., len(param), ..., 1]``
            with ``len(param)`` at position *axis*
        :raises RuntimeError: if a per-axis *param* cannot be aligned
            with dimension *axis* of *x*
        """
        if param.size == 1:
            # A single value is a per-tensor parameter whatever its rank.
            return param.reshape(())
        rank = len(x.shape)
        axis = self.axis + rank if self.axis < 0 else self.axis
        # Explicit check: a length mismatch against a dimension of size 1
        # would otherwise broadcast silently and change the output shape.
        if not 0 <= axis < rank or x.shape[axis] != param.shape[0]:
            raise RuntimeError(
                "Parameter of shape {} cannot be applied along axis {} "
                "of an input of shape {} in DequantizeLinear.".format(
                    param.shape, self.axis, x.shape))
        new_shape = [1] * rank
        new_shape[axis] = param.shape[0]
        return param.reshape(new_shape)

    def _run(self, *args):  # pylint: disable=W0221
        """
        Dequantizes ``args[0]`` with scale ``args[1]`` and the optional
        zero point ``args[2]``.

        :return: a one-element tuple holding the float32 result
        :raises RuntimeError: if the scale or the zero point has more than
            one dimension, or if the zero point dtype differs from the
            input dtype
        """
        x = args[0]
        x_scale = args[1]
        # The zero point is optional: it may be absent or given as None.
        x_zero_point = args[2] if len(args) > 2 else None

        if len(x_scale.shape) > 1:
            raise RuntimeError(  # pragma: no cover
                "Input 2 must be a vector or a number.")
        if x_zero_point is not None:
            if len(x_zero_point.shape) > 1:
                raise RuntimeError(  # pragma: no cover
                    "Input 3 must be a vector or a number.")
            if x_zero_point.dtype != x.dtype:
                raise RuntimeError(  # pragma: no cover
                    "Type mismatch {} != {} in DequantizeLinear.".format(
                        x.dtype, x_zero_point.dtype))

        # Convert before subtracting so unsigned integers cannot wrap around.
        y = x.astype(numpy.float32)
        if x_zero_point is not None:
            y = y - self._reshape_parameter(x_zero_point, x)
        y = y * self._reshape_parameter(x_scale, x)
        return (y.astype(numpy.float32), )

    def _infer_shapes(self, *args):  # pylint: disable=W0221
        # The output keeps the shape of the quantized input, as float32.
        return (ShapeObject(args[0].shape, dtype=numpy.float32), )

    def _infer_types(self, *args):  # pylint: disable=W0221
        # The output type is always float32.
        return (numpy.float32, )

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        # Runs the operator; no temporary buffer is reported (temp=0).
        res = self.run(*args)
        return (dict(temp=0), ) + res