Coverage for mlprodict/testing/einsum/einsum_bench.py: 100%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

81 statements  

1""" 

2@file 

3@brief Function to measure the performance of einsum decomposition. 

4""" 

5from itertools import permutations 

6import numpy 

7from onnx import helper, TensorProto 

8from cpyquickhelper.numbers import measure_time 

9from ... import __max_supported_opset__, get_ir_version 

10from ...tools.ort_wrapper import InferenceSession 

11from ...onnxrt import OnnxInference 

12from .einsum_impl import decompose_einsum_equation, apply_einsum_sequence 

13 

14 

def _measure_time(stmt, *x, repeat=5, number=5, div_by_number=True,
                  first_run=True, max_time=None):
    """
    Measures a statement and returns the results as a dictionary.

    :param stmt: callable to measure, it is called as ``stmt(*x)``
    :param x: inputs given to *stmt*
    :param repeat: average over *repeat* experiment
    :param number: number of executions in one row
    :param div_by_number: divide by the number of executions
    :param first_run: if True, runs the function once before measuring
    :param max_time: execute the statement until the total goes
        beyond this time (approximatively), *repeat* is ignored,
        *div_by_number* must be set to True
    :return: dictionary (whatever *measure_time* returns)
    :raises RuntimeError: if the warm-up call raises a *RuntimeError*,
        the new exception lists the type and the dtype of every input
        and is chained to the original one

    See `Timer.repeat
    <https://docs.python.org/3/library/timeit.html?timeit.Timer.repeat>`_
    for a better understanding of parameter *repeat* and *number*.
    The function returns a duration corresponding to
    *number* times the execution of the main statement.
    """
    if first_run:
        # Single warm-up call, it also surfaces failures before any
        # measure starts.
        try:
            stmt(*x)
        except RuntimeError as e:  # pragma: no cover
            # *x* is the tuple of inputs, it has no attribute *dtype*:
            # every input is described separately, inputs without a dtype
            # are reported as None.
            raise RuntimeError("{}-{}".format(
                [type(v) for v in x],
                [getattr(v, 'dtype', None) for v in x])) from e

    def fct():
        stmt(*x)

    return measure_time(fct, context={}, repeat=repeat, number=number,
                        div_by_number=div_by_number, max_time=max_time)

51 

52 

def _make_einsum_model(equation, opset=__max_supported_opset__):
    """
    Builds an ONNX model holding a single *Einsum* node.

    The model has one float input per term found on the left side of
    ``->`` in *equation*, named ``X0``, ``X1``, ..., and one float output
    named ``Y``. No shape is declared for any of them.

    :param equation: einsum equation, stored in the node attribute
        *equation*
    :param opset: opset of the default domain, it also selects the
        IR version through *get_ir_version*
    :return: the model returned by *helper.make_model*
    """
    n_terms = len(equation.split('->')[0].split(','))
    names = ["X%d" % i for i in range(n_terms)]

    graph_inputs = [
        helper.make_tensor_value_info(
            name, TensorProto.FLOAT, None)  # pylint: disable=E1101
        for name in names]
    graph_outputs = [
        helper.make_tensor_value_info(
            "Y", TensorProto.FLOAT, None)]  # pylint: disable=E1101
    einsum_node = helper.make_node(
        "Einsum", list(names), ["Y"], equation=equation)

    graph = helper.make_graph(
        name='einsum_test',
        inputs=graph_inputs,
        outputs=graph_outputs,
        nodes=[einsum_node])

    return helper.make_model(
        opset_imports=[helper.make_operatorsetid('', opset)],
        ir_version=get_ir_version(opset),
        producer_name='mlprodict',
        producer_version='0.1',
        graph=graph)

78 

79 

def _make_inputs(equation, shapes):
    """
    Creates random float32 inputs for an einsum equation.

    :param equation: einsum equation, only the terms on the left side
        of ``->`` are used
    :param shapes: an integer (every dimension of every input gets this
        size) or a sequence of shapes, one per input
    :return: list of float32 arrays filled by *numpy.random.randn*
    :raises ValueError: if *shapes* is a sequence whose length differs
        from the number of terms in *equation*
    """
    # Whitespace is not a subscript: it is removed before counting the
    # letters of a term, otherwise "abc, cd" would give a rank of 3 to
    # the second input.
    terms = ["".join(term.split())
             for term in equation.split('->')[0].split(',')]

    if isinstance(shapes, int):
        shapes = [(shapes, ) * len(term) for term in terms]
    elif len(shapes) != len(terms):
        raise ValueError(  # pragma: no cover
            "Unexpected number of shapes %r with equation %r."
            "" % (shapes, equation))

    return [numpy.random.randn(*sh).astype(numpy.float32) for sh in shapes]

94 

95 

def einsum_benchmark(equation="abc,cd->abd", shape=30, perm=False,
                     runtime='python', use_tqdm=False,
                     number=5, repeat=5, opset=__max_supported_opset__):
    """
    Investigates whether or not the decomposing einsum is faster.

    Every tested equation is measured twice, once as a single einsum
    (``dec == 'einsum'``) and once decomposed by
    *decompose_einsum_equation* (``dec == 'dec'``).

    :param equation: einsum equation to test
    :param shape: an integer (all dimensions get the same size),
        a list of integers (one measure per size), any other value
        is given as is to *_make_inputs* as a single list of shapes
    :param perm: if True, measures every permutation of the letters
        of the equation, the equation must then be in lower case
    :param runtime: numpy, python, onnxruntime
    :param use_tqdm: show progress
    :param number: usual parameter to measure a function
    :param repeat: usual parameter to measure a function
    :param opset: target opset
    :return: list of dictionaries as an iterator, every dictionary
        is the output of the measure completed with keys
        *rt*, *dec*, *eq*, *shapes*
    :raises ValueError: if *perm* is True and the equation contains
        upper case letters, or if *runtime* is not supported

    The function is a generator: nothing runs and no exception is
    raised before the first result is requested.
    """
    if isinstance(shape, list) and all(isinstance(t, int) for t in shape):
        shape_list = shape
    else:
        shape_list = [shape]

    if perm:
        if equation.lower() != equation:
            raise ValueError(
                "Only equations with lower letters are allowed but equation %r "
                "is not." % equation)
        # Upper bound is inclusive so that letter 'z' is permuted as well.
        letters = sorted({c for c in equation if "a" <= c <= "z"})
        target = "".join(letters)
        # str.translate replaces all the letters in one pass, a letter
        # already replaced cannot be replaced a second time.
        equations = [
            equation.translate(str.maketrans("".join(p), target))
            for p in permutations(letters)]
    else:
        equations = [equation]

    scenarios = [(eq, runtime, dec, sh)
                 for eq in equations
                 for dec in ('einsum', 'dec')
                 for sh in shape_list]

    if use_tqdm:
        from tqdm import tqdm  # pragma: no cover
        loop = tqdm(scenarios)  # pragma: no cover
    else:
        loop = scenarios

    for eq, rt, dec, sh in loop:
        inputs = _make_inputs(eq, sh)

        if dec == 'dec':
            seq = decompose_einsum_equation(eq, strategy='numpy', clean=True)
        else:
            seq = None

        # Default values in the lambda functions bind the objects of the
        # current iteration.
        if rt == 'numpy':
            if dec == 'einsum':
                fct = lambda *x, eq=eq: numpy.einsum(eq, *x, optimize=True)
            else:
                fct = lambda *x, seq=seq: apply_einsum_sequence(seq, *x)
        elif rt in ('onnxruntime', 'python'):
            # Both runtimes execute the same ONNX graph. The model is built
            # from *eq*, the equation reported in the results, and not from
            # the initial equation.
            if dec == 'einsum':
                onx = _make_einsum_model(eq, opset=opset)
            else:
                onx = seq.to_onnx(
                    'Y', *["X%d" % i for i in range(len(inputs))],
                    opset=opset)
            if rt == 'onnxruntime':
                sess = InferenceSession(
                    onx.SerializeToString(),
                    providers=['CPUExecutionProvider'])  # pylint: disable=W0612
                fct = lambda *x, se=sess: se.run(
                    None, {"X%d" % i: v for i, v in enumerate(x)})
            else:
                oinf = OnnxInference(onx)  # pylint: disable=W0612
                fct = lambda *x, oi=oinf: oi.run(
                    {"X%d" % i: v for i, v in enumerate(x)})
        else:
            raise ValueError("Unexpected runtime %r." % rt)

        res = _measure_time(fct, *inputs, repeat=repeat, number=number)
        res['rt'] = rt
        res['dec'] = dec
        res['eq'] = eq
        res['shapes'] = ";".join(
            map(str, [m.shape for m in inputs])).replace(' ', '')
        yield res