Coverage for mlprodict/onnxrt/validate/validate_python.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

65 statements  

1""" 

2@file 

3@brief Helpers to validate python code. 

4""" 

5import pickle 

6import pprint 

7import numpy 

8from numpy.linalg import det as npy_det # pylint: disable=E0611 

9from scipy.spatial.distance import cdist # pylint: disable=E0611 

10from scipy.special import expit, erf # pylint: disable=E0611 

11from scipy.linalg import solve # pylint: disable=E0611 

12from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE 

13from ...tools.code_helper import make_callable 

14 

15 

16def _make_callable(fct, obj, code, gl, debug): 

17 """ 

18 Same function as @see fn make_callable but deals with 

19 function which an undefined number of arguments. 

20 """ 

21 def pyrt_Concat_(*inputs, axis=0): 

22 return numpy.concatenate(inputs, axis=axis) 

23 

24 if fct == "pyrt_Concat": 

25 return pyrt_Concat_ 

26 return make_callable(fct, obj, code, gl, debug) 

27 

28 

def _is_nan(value):
    """
    Tells if *value* is a floating point NaN, either a python
    float or a numpy floating scalar (elements of a *float32*
    array are not instances of python's float).
    """
    return isinstance(value, (float, numpy.floating)) and bool(numpy.isnan(value))


def validate_python_inference(oinf, inputs, tolerance=0.):
    """
    Validates the code produced by method :meth:`to_python
    <mlprodict.onnxrt.onnx_inference_exports.OnnxInferenceExport.to_python>`.
    The function compiles and executes the code
    given as an argument and compares the results to
    what *oinf* returns. This function is mostly used for
    unit testing purpose but it is not robust enough
    to handle all cases.

    @param      oinf        @see cl OnnxInference
    @param      inputs      inputs as dictionary, it is not modified
    @param      tolerance   discrepancies must be below or equal to
                            this threshold, ``'random'`` disables the
                            comparison of values (shapes are still checked)

    The function fails if the expected outputs are not the same.
    It raises *SyntaxError* if the generated code does not compile,
    *RuntimeError* if it cannot be executed or if the number or the
    names of the outputs differ, *ValueError* if shapes or values differ,
    *NotImplementedError* if an expected output is not a numpy array.
    A NaN on one side only counts as an infinite discrepancy.
    """
    from ..ops_cpu.op_argmax import _argmax
    from ..ops_cpu.op_argmin import _argmin
    from ..ops_cpu.op_celu import _vcelu1
    from ..ops_cpu.op_leaky_relu import _leaky_relu

    cd = oinf.to_python()
    code = cd['onnx_pyrt_main.py']

    exp = oinf.run(inputs)
    if not isinstance(exp, dict):
        raise TypeError(  # pragma: no cover
            "exp is not a dictionary but '{}'.".format(type(exp)))
    if len(exp) == 0:
        raise ValueError(  # pragma: no cover
            "No result to compare.")

    # Appends a call to the generated runtime, every input is passed by name.
    inps = ['{0}={0}'.format(k) for k in sorted(inputs)]
    code += "\n".join(['', '', 'opi = OnnxPythonInference()',
                       'res = opi.run(%s)' % ', '.join(inps)])

    try:
        cp = compile(code, "<string>", mode='exec')
    except SyntaxError as e:
        raise SyntaxError(
            "Error %s in code\n%s" % (str(e), code)) from e
    pyrt_fcts = [name for name in cp.co_names if name.startswith("pyrt_")]
    fcts_local = {}

    gl = {'numpy': numpy, 'pickle': pickle,
          'expit': expit, 'erf': erf, 'cdist': cdist,
          '_argmax': _argmax, '_argmin': _argmin,
          '_vcelu1': _vcelu1, 'solve': solve,
          'fft': numpy.fft.fft, 'rfft': numpy.fft.rfft,
          'fft2': numpy.fft.fft2,
          'npy_det': npy_det, 'ndarray': numpy.ndarray,
          '_leaky_relu': _leaky_relu,
          'nan': numpy.nan,
          'TENSOR_TYPE_TO_NP_TYPE': TENSOR_TYPE_TO_NP_TYPE}

    # exec is called with distinct globals and locals: top-level definitions
    # of the generated code land in the locals, which function bodies do not
    # see, so the ``pyrt_*`` functions are rebuilt and injected into globals.
    for fct in pyrt_fcts:
        for obj in cp.co_consts:
            if isinstance(obj, str):
                continue
            sobj = str(obj)
            if '<string>' in sobj and fct in sobj:
                fcts_local[fct] = _make_callable(fct, obj, code, gl, False)

    gl.update(fcts_local)
    # A copy: exec stores 'opi', 'res'... in the locals and must not
    # pollute the caller's dictionary.
    loc = dict(inputs)
    try:
        exec(cp, gl, loc)  # pylint: disable=W0122
    except (NameError, TypeError, SyntaxError,  # pragma: no cover
            IndexError, ValueError) as e:
        raise RuntimeError(
            "Unable to execute code.\n{}\n-----\n{}".format(e, code)) from e

    got = loc['res']
    keys = sorted(exp)
    if isinstance(got, numpy.ndarray) and len(keys) == 1:
        # single output returned as a bare array
        got = {keys[0]: got}

    if not isinstance(got, dict):
        raise TypeError(  # pragma: no cover
            "got is not a dictionary but '{}'\n--\n{}\n---\n{}\n--code--\n{}".format(
                type(got), dir(got), pprint.pformat(str(loc)), code))
    if len(got) != len(exp):
        raise RuntimeError(  # pragma: no cover
            "Different number of results.\nexp: {}\ngot: {}\n--code--\n{}".format(
                ", ".join(sorted(exp)), ", ".join(sorted(got)), code))

    if keys != sorted(got):
        raise RuntimeError(  # pragma: no cover
            "Different result names.\nexp: {}\ngot: {}\n--code--\n{}".format(
                ", ".join(sorted(exp)), ", ".join(sorted(got)), code))

    for k in keys:
        e = exp[k]
        g = got[k]
        if not isinstance(e, numpy.ndarray):
            raise NotImplementedError(  # pragma: no cover
                "Unable to compare values of type '{}'.".format(type(e)))
        if e.shape != g.shape:
            raise ValueError(  # pragma: no cover
                "Shapes are different {} != {}\n---\n{}\n{}.".format(
                    e.shape, g.shape, e, g))
        if tolerance == 'random':
            # values are random, only shapes can be compared
            continue
        diff = 0
        for a, b in zip(e.ravel(), g.ravel()):
            if a == b:
                continue
            a_nan, b_nan = _is_nan(a), _is_nan(b)
            if a_nan and b_nan:
                continue  # pragma: no cover
            if a_nan or b_nan:
                # max(diff, nan) would silently return diff because any
                # comparison with NaN is False, the mismatch must be reported.
                diff = numpy.inf
                break
            diff = max(diff, abs(a - b))
        if diff > tolerance:
            raise ValueError(  # pragma: no cover
                "Values are different (max diff={}>{})\n--EXP--\n{}\n--GOT--"
                "\n{}\n--\n{}".format(diff, tolerance, e, g, code))