Coverage for mlprodict/testing/einsum/einsum_bench.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Function to measure the performance of einsum decomposition.
4"""
5from itertools import permutations
6import numpy
7from onnx import helper, TensorProto
8from cpyquickhelper.numbers import measure_time
9from ... import __max_supported_opset__, get_ir_version
10from ...tools.ort_wrapper import InferenceSession
11from ...onnxrt import OnnxInference
12from .einsum_impl import decompose_einsum_equation, apply_einsum_sequence
def _measure_time(stmt, *x, repeat=5, number=5, div_by_number=True,
                  first_run=True, max_time=None):
    """
    Measures a statement and returns the results as a dictionary.

    :param stmt: callable to measure, it is called as ``stmt(*x)``
    :param *x: inputs
    :param repeat: average over *repeat* experiment
    :param number: number of executions in one row
    :param div_by_number: divide by the number of executions
    :param first_run: if True, runs the function once before measuring
    :param max_time: execute the statement until the total goes
        beyond this time (approximately), *repeat* is ignored,
        *div_by_number* must be set to True
    :return: dictionary
    :raises RuntimeError: if the warm-up run raises a *RuntimeError*,
        the new message lists the type and dtype of every input

    See `Timer.repeat
    <https://docs.python.org/3/library/timeit.html#timeit.Timer.repeat>`_
    for a better understanding of parameter *repeat* and *number*.
    The function returns a duration corresponding to
    *number* times the execution of the main statement.
    """
    def fct():
        stmt(*x)

    if first_run:
        # A single warm-up run: it both validates the inputs and keeps
        # one-off initialisation costs out of the measure.
        try:
            fct()
        except RuntimeError as e:  # pragma: no cover
            # *x* is a tuple of inputs, every input is described separately
            # (inputs without a dtype are reported with dtype None).
            described = ", ".join(
                "{}-{}".format(type(v), getattr(v, 'dtype', None))
                for v in x)
            raise RuntimeError(described) from e

    return measure_time(fct, context={}, repeat=repeat, number=number,
                        div_by_number=div_by_number, max_time=max_time)
def _make_einsum_model(equation, opset=__max_supported_opset__):
    """
    Builds an ONNX model holding a single *Einsum* node.

    :param equation: einsum equation, the number of comma-separated
        terms on the left side of ``->`` gives the number of inputs
    :param opset: opset of the default domain, it also selects
        the IR version through *get_ir_version*
    :return: ONNX model with float inputs ``X0``, ``X1``, ...
        and one float output ``Y``, all declared without any shape
    """
    n_inputs = len(equation.split('->')[0].split(','))
    input_names = ["X%d" % i for i in range(n_inputs)]

    einsum_node = helper.make_node(
        "Einsum", input_names, ["Y"], equation=equation)

    graph_inputs = [
        helper.make_tensor_value_info(
            name, TensorProto.FLOAT, None)  # pylint: disable=E1101
        for name in input_names]
    graph_output = helper.make_tensor_value_info(
        "Y", TensorProto.FLOAT, None)  # pylint: disable=E1101

    graph = helper.make_graph(
        nodes=[einsum_node],
        name='einsum_test',
        inputs=graph_inputs,
        outputs=[graph_output])

    return helper.make_model(
        graph=graph,
        opset_imports=[helper.make_operatorsetid('', opset)],
        ir_version=get_ir_version(opset),
        producer_name='mlprodict',
        producer_version='0.1')
def _make_inputs(equation, shapes):
    """
    Creates random float32 inputs for an einsum equation.

    :param equation: einsum equation, only the left side of ``->``
        is used, one term per input
    :param shapes: an integer (every dimension of every input gets
        this size) or a sequence of shapes, one per input
    :return: list of float32 arrays drawn with ``numpy.random.randn``
    :raises ValueError: if the number of shapes is different
        from the number of terms in the equation
    """
    terms = equation.split('->')[0].split(',')

    if isinstance(shapes, int):
        # Whitespace is not a subscript: it must not be counted as a
        # dimension when deducing the rank of every input.
        ranks = [len("".join(term.split())) for term in terms]
        shapes = [(shapes, ) * rank for rank in ranks]
    elif len(shapes) != len(terms):
        raise ValueError(  # pragma: no cover
            "Unexpected number of shapes %r with equation %r."
            "" % (shapes, equation))

    # numpy.random.randn() returns a python float for an empty shape,
    # numpy.asarray converts it into a 0-d array like any other input.
    return [numpy.asarray(numpy.random.randn(*sh), dtype=numpy.float32)
            for sh in shapes]
def einsum_benchmark(equation="abc,cd->abd", shape=30, perm=False,
                     runtime='python', use_tqdm=False,
                     number=5, repeat=5, opset=__max_supported_opset__):
    """
    Investigates whether or not the decomposing einsum is faster.

    :param equation: einsum equation to test
    :param shape: an integer (all dimension gets the same size),
        a list of integers (one benchmark per size) or
        a list of shapes, one per input of the equation
    :param perm: if True, benchmarks every equation obtained by
        permuting the letters of *equation*, the equation must
        then only contain lower case letters
    :param runtime: numpy, python, onnxruntime
    :param use_tqdm: show progress
    :param number: usual parameter to measure a function
    :param repeat: usual parameter to measure a function
    :param opset: target opset
    :return: list of dictionaries as an iterator, every dictionary is
        the output of *_measure_time* completed with keys
        ``'rt'``, ``'dec'``, ``'eq'``, ``'shapes'``
    :raises ValueError: if *perm* is True and the equation contains
        upper case letters, or if the runtime is unknown

    The function is a generator: nothing runs, not even the validation
    of the parameters, before the first result is requested.
    """
    # A list of integers means several sizes to benchmark, anything else
    # is a single shape specification forwarded to _make_inputs.
    if isinstance(shape, list) and all(isinstance(t, int) for t in shape):
        shape_list = shape
    else:
        shape_list = [shape]

    if perm:
        if equation.lower() != equation:
            raise ValueError(
                "Only equations with lower letters are allowed but equation %r "
                "is not." % equation)
        # 'z' is a valid subscript too, hence the inclusive upper bound.
        letters = sorted({c for c in equation if "a" <= c <= "z"})
        target = "".join(letters)
        # str.translate replaces every letter simultaneously, no letter
        # can be replaced twice.
        equations = [
            equation.translate(str.maketrans("".join(p), target))
            for p in permutations(letters)]
    else:
        equations = [equation]

    scenarios = [(eq, runtime, dec, sh)
                 for eq in equations
                 for dec in ('einsum', 'dec')
                 for sh in shape_list]

    if use_tqdm:
        from tqdm import tqdm  # pragma: no cover
        loop = tqdm(scenarios)  # pragma: no cover
    else:
        loop = scenarios

    for eq, rt, dec, sh in loop:
        inputs = _make_inputs(eq, sh)
        input_names = ["X%d" % i for i in range(len(inputs))]

        if dec == 'dec':
            seq = decompose_einsum_equation(eq, strategy='numpy', clean=True)
        else:
            seq = None

        if rt == 'numpy':
            if dec == 'einsum':
                fct = lambda *x, eq=eq: numpy.einsum(eq, *x, optimize=True)
            else:
                fct = lambda *x, seq=seq: apply_einsum_sequence(seq, *x)
        elif rt in ('onnxruntime', 'python'):
            if dec == 'einsum':
                # The baseline must run the same (possibly permuted)
                # equation as the decomposition it is compared to.
                onx = _make_einsum_model(eq, opset=opset)
            else:
                onx = seq.to_onnx('Y', *input_names, opset=opset)
            if rt == 'onnxruntime':
                sess = InferenceSession(
                    onx.SerializeToString(),
                    providers=['CPUExecutionProvider'])
                fct = lambda *x, se=sess: se.run(
                    None, {"X%d" % i: v for i, v in enumerate(x)})
            else:
                oinf = OnnxInference(onx)
                fct = lambda *x, oi=oinf: oi.run(
                    {"X%d" % i: v for i, v in enumerate(x)})
        else:
            raise ValueError("Unexpected runtime %r." % rt)

        res = _measure_time(fct, *inputs, repeat=repeat, number=number)
        res['rt'] = rt
        res['dec'] = dec
        res['eq'] = eq
        res['shapes'] = ";".join(
            map(str, [m.shape for m in inputs])).replace(' ', '')
        yield res