Coverage for mlprodict/onnxrt/ops_cpu/op_quantize_linear.py: 82%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

50 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ...onnx_tools.onnx2py_helper import guess_numpy_type_from_dtype 

9from ._op import OpRun 

10from ..shape_object import ShapeObject 

11 

12 

class QuantizeLinear(OpRun):
    # Runtime for the ONNX operator QuantizeLinear:
    # ``y = saturate(round(x / y_scale) + y_zero_point)``.
    # Kept as comments rather than a class docstring: operator documentation
    # may be generated elsewhere when ``__doc__`` is empty
    # (NOTE(review): confirm before adding a docstring).

    atts = {'axis': 1}
    python_inputs = ['*inputs']

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=QuantizeLinear.atts,
                       **options)

    def _run(self, *args):  # pylint: disable=W0221
        """
        Quantizes *x* with *y_scale* and an optional *y_zero_point*.

        :param args: ``(x, y_scale)`` or ``(x, y_scale, y_zero_point)``;
            *y_scale* is a number or a vector (one value per slice along
            ``self.axis``), *y_zero_point* is expected to match its shape
        :return: 1-tuple holding the quantized array, with the dtype of
            *y_zero_point* or ``numpy.uint8`` when there is no zero point
        :raises RuntimeError: if *y_scale* has more than one dimension or
            the zero point dtype is neither uint8 nor int8
        """
        x, y_scale = args[0], args[1]
        if len(y_scale.shape) > 1:
            raise RuntimeError(  # pragma: no cover
                "Input 2 must be a vector or a number.")
        zero_point = args[2] if len(args) > 2 else None

        # Output dtype and saturation bounds; uint8 with zero point 0 is
        # used when no zero point is given.
        dtype = numpy.uint8 if zero_point is None else zero_point.dtype
        if dtype == numpy.uint8:
            low, high = 0, 255
        elif dtype == numpy.int8:
            low, high = -128, 127
        else:
            raise RuntimeError(  # pragma no cover
                "Unexpected dtype for input 2 {}.".format(dtype))

        if len(y_scale.shape) > 0 and y_scale.size == 1:
            # A one-element vector behaves like a scalar scale.
            y_scale = y_scale[0]
        if len(y_scale.shape) > 0:
            # Per-axis quantization: one scale (and zero point) per slice
            # along ``self.axis``, broadcast over every other dimension.
            new_shape = [1] * len(x.shape)
            new_shape[self.axis] = len(y_scale)
            y_scale = y_scale.reshape(new_shape)
            if zero_point is not None:
                zero_point = zero_point.reshape(new_shape)
        elif zero_point is not None and zero_point.size == 1:
            # Collapse a one-element zero point to 0-d so it broadcasts
            # against inputs of any rank, including 0-d inputs.
            zero_point = zero_point.reshape(())

        # ONNX rounds x / y_scale to the nearest integer with ties to even
        # (numpy.rint) *before* adding the zero point: rounding after the
        # addition would flip ties for odd zero points, and casting
        # without rounding would truncate toward zero.
        y = numpy.rint(x / y_scale)
        if zero_point is not None:
            y = y + zero_point
        return (numpy.clip(y, low, high).astype(dtype), )

    def _infer_shapes(self, *args):  # pylint: disable=W0221
        """
        Returns the shape of *x* with the zero point dtype, or
        ``numpy.uint8`` when there is no zero point.
        """
        dtype = args[2].dtype if len(args) > 2 else numpy.uint8
        return (ShapeObject(args[0].shape, dtype=dtype), )

    def _infer_types(self, *args):  # pylint: disable=W0221
        """
        Returns the zero point type, or ``numpy.uint8`` when there is no
        zero point.
        """
        if len(args) <= 2:
            return (numpy.uint8, )
        zero_point = args[2]
        if isinstance(zero_point, numpy.ndarray):
            return (zero_point.dtype, )
        # Otherwise args[2] is presumably a type descriptor that
        # guess_numpy_type_from_dtype maps to a numpy type (verify).
        return (guess_numpy_type_from_dtype(zero_point), )

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        """
        Runs the operator and returns ``dict(temp=0)`` (no temporary
        buffer reported) followed by the outputs of ``self.run``.
        """
        res = self.run(*args)
        return (dict(temp=0), ) + res