Coverage for mlprodict/onnxrt/ops_cpu/op_dequantize_linear.py: 86%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

36 statements  

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ..shape_object import ShapeObject 

10 

11 

class DequantizeLinear(OpRun):
    """
    Runtime for the ONNX operator *DequantizeLinear*.

    Computes ``y = (x - x_zero_point) * x_scale`` in ``float32``.
    The inputs are positional: ``x``, ``x_scale`` and the optional
    ``x_zero_point``. When the zero point is missing, the operator
    computes ``y = x * x_scale``.
    """

    # Default value of the operator attributes. *axis* is the dimension
    # of ``x`` a vector scale / zero point is aligned with.
    atts = {'axis': 1}
    python_inputs = ['*inputs']

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=DequantizeLinear.atts,
                       **options)

    def _run(self, *args):  # pylint: disable=W0221
        """
        Dequantizes ``args[0]``.

        @param      args    ``(x, x_scale)`` or
                            ``(x, x_scale, x_zero_point)``
        @return             tuple with a single ``float32`` array

        Raises ``RuntimeError`` if the scale has more than one dimension
        or if the zero point type differs from the type of ``x``.
        """
        x, x_scale = args[0], args[1]
        if len(x_scale.shape) > 1:
            raise RuntimeError(  # pragma: no cover
                "Input 2 must be a vector or a number.")

        # The zero point is optional: it must only be read when the node
        # really received a third input.
        if len(args) > 2:
            x_zero_point = args[2]
            # A one-element zero point is handled as a number.
            if len(x_zero_point.shape) > 0 and x_zero_point.size == 1:
                x_zero_point = x_zero_point[0]
            if x_zero_point.dtype != x.dtype:
                raise RuntimeError(  # pragma no cover
                    "Type mismatch {} != {} in DequantizeLinear.".format(
                        x.dtype, x_zero_point.dtype))

            x_float = x.astype(numpy.float32)
            if len(x_zero_point.shape) > 0:
                # Per-axis parameters: reshape them so that they
                # broadcast along dimension *axis* of x.
                new_shape = [1] * len(x.shape)
                new_shape[self.axis] = len(x_zero_point)
                centered = x_float - x_zero_point.reshape(new_shape)
                y = centered * x_scale.reshape(new_shape)
            else:
                y = (x_float - x_zero_point) * x_scale
        else:
            # No zero point: only the scale is applied.
            # A one-element scale is handled as a number, like the
            # zero point above.
            if len(x_scale.shape) > 0 and x_scale.size == 1:
                x_scale = x_scale[0]
            x_float = x.astype(numpy.float32)
            if len(x_scale.shape) > 0:
                new_shape = [1] * len(x.shape)
                new_shape[self.axis] = len(x_scale)
                y = x_float * x_scale.reshape(new_shape)
            else:
                y = x_float * x_scale
        return (y.astype(numpy.float32), )

    def _infer_shapes(self, *args):  # pylint: disable=W0221
        """
        The output has the shape of the first input and type ``float32``.
        """
        return (ShapeObject(args[0].shape, dtype=numpy.float32), )

    def _infer_types(self, *args):  # pylint: disable=W0221
        """
        The output type is always ``float32``.
        """
        return (numpy.float32, )

    def _infer_sizes(self, *args):  # pylint: disable=W0221
        """
        Runs the operator and returns ``dict(temp=0)`` followed by
        the outputs.
        """
        res = self.run(*args)
        return (dict(temp=0), ) + res