Coverage for mlprodict/testing/einsum/einsum_fct.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Main functions decomposing einsum computation into
4more simple functions.
5"""
6from itertools import permutations
7import time
8import math
9import numpy
10from onnx import helper
11from ...onnx_tools.onnx2py_helper import guess_proto_dtype
12from ...onnxrt.onnx_micro_runtime import OnnxMicroRuntime
13from ... import __max_supported_opset__, get_ir_version
14from .einsum_impl import decompose_einsum_equation, apply_einsum_sequence
15from .einsum_ml import predict_transposition_cost
# Module-level cache: maps (equation, runtime, opset, optimize, dtype,
# decompose, strategy) keys to built CachedEinsum instances (see _einsum).
_einsum_cache = {}
def enumerate_cached_einsum():
    """
    Enumerates all cached einsum function.

    Yields pairs ``(key, CachedEinsum)`` taken from the module-level cache.
    """
    global _einsum_cache  # pylint: disable=W0603,W0602
    yield from _einsum_cache.items()
class CachedEinsum:
    """
    Stores all the necessary information to cache the preprocessing
    of an einsum equation.

    :param equation: numpy equation
    :param runtime: see :func:`einsum
        <mlprodict.testing.einsum.einsum_fct.einsum>`
    :param opset: ONNX opset
    :param optimize: finds the best letter permutation
    :param dtype: dtype
    :param decompose: to decompose Einsum operator or to keep it as is
    :param key: key used to cache this class
    :param strategy: optimization strategy
    :param verbose: displays progress information

    The class creates the following attributes:

    * `equation_` corresponding to the best equivalent equation
    * `graph_`: the corresponding graph returned by function
        :func:`decompose_einsum_equation
        <mlprodict.testing.einsum.einsum_impl.decompose_einsum_equation> `
    * `onnx_`: if a conversion to onnx is used, stores the onnx graph
    * `runtime_`: a function used by `__call__`, calls the runtime
    """

    def __init__(self, equation, runtime='batch_dot', opset=None,
                 optimize=False, dtype=numpy.float64, decompose=True,
                 strategy=None, verbose=None, key=None):
        self.equation = equation
        self.runtime = runtime
        self.opset = opset
        self.optimize = optimize
        self.dtype = dtype
        self.decompose = decompose
        self.strategy = strategy
        self.verbose = verbose
        self.key = key

    def __repr__(self):
        "usual"
        return "%s(%r, %r, %r, %r, %r, %r, %r, key=%r)" % (
            self.__class__.__name__, self.equation, self.runtime,
            self.opset, self.optimize, self.dtype, self.decompose,
            self.strategy, self.key)

    def default_inputs(self, N=None):
        """
        Returns default inputs (reshaped numpy.arange + 0.7i).

        :param N: dimension (all dimension have the same size)

        If *N is None*, N is given a size depending on the number of letters
        to avoid spending too much time on optimization.
        """
        if N is None:
            # heuristic: more letters means more permutations to try,
            # so shrink the dimension to keep the total cost bounded
            letters = set(c for c in self.equation
                          if "a" <= c <= "z" or "A" <= c <= "Z")
            nn = math.factorial(len(letters))
            N = max(int(2 ** 11 / nn), 4)
            N = min(N, 15)
        # one input per term on the left side of '->'
        inps = self.equation.split('->')[0].split(',')
        lens = [len(s) for s in inps]
        inputs = [numpy.arange(N ** d).reshape((N,) * d) for d in lens]
        # the 0.7 * ii shift makes every input distinct
        inputs = [(i + 0.7 * ii).astype(self.dtype)
                  for ii, i in enumerate(inputs)]
        return inputs

    def build(self):
        """
        Preprocesses the equation builds whatever is necessary
        to compute the result of the einsum equation.
        """
        if not self.optimize and not hasattr(self, 'equation_'):
            self.equation_ = self.equation
        elif self.strategy is None:
            # optimization by measuring actual execution time
            self.equation_ = self._build_optimize()
        elif self.strategy == 'ml':
            # optimization by predicting transposition costs with a model
            self.equation_ = self._build_optimize_ml()
        else:
            raise ValueError(  # pragma: no cover
                "Unknown strategy %r." % self.strategy)
        self.build_runtime()

    def _build_optimize(self):
        # Searches for the letter permutation leading to the fastest
        # execution: every permutation is built and timed (10 runs).
        # loops over all permutations
        if self.equation.lower() != self.equation:
            raise RuntimeError(  # pragma: no cover
                "Only lower equation can be optimized, %r is not." % self.equation)
        letters = list(
            sorted(set(c for c in self.equation if "a" <= c <= "z")))
        possible = list(permutations(letters))
        possible.insert(0, letters)
        if self.verbose:
            from tqdm import tqdm  # pragma: no cover
            subset = tqdm(possible)  # pragma: no cover
        else:
            subset = possible
        best = []
        confs = []
        very_best = None
        inputs = None
        for perm in subset:
            # rewrite the equation with the permuted letters;
            # upper-case is used as a temporary alphabet to avoid clashes
            replace = {d: c for c, d in zip(letters, perm)}
            eq = self.equation
            for k, v in replace.items():
                eq = eq.replace(k, v.upper())
            eq = eq.lower()
            inst = CachedEinsum(eq, runtime=self.runtime, opset=self.opset,
                                optimize=False, dtype=self.dtype,
                                decompose=self.decompose)
            inst.build()
            if inputs is None:
                inputs = inst.default_inputs()
                inst(*inputs)
            ts = time.perf_counter()
            for _ in range(0, 10):
                inst(*inputs)
            delta = time.perf_counter() - ts
            confs.append((delta, eq))
            # keep the 10 best (smallest delta) candidates sorted
            if len(best) < 10:
                best.append((delta, eq))
                best.sort()
            elif delta < best[-1][0]:
                best[-1] = (delta, eq)
                best.sort()
            if self.verbose and (
                    very_best is None or very_best != best[0][0]):
                very_best = best[0][0]
                subset.set_description("%1.2g rtbest=%r" % best[0])
        self.optimized_ = best
        self.timed_permutations_ = confs
        return best[0][1]

    def _build_optimize_ml(self):
        # Searches for the best permutation without executing the einsum:
        # the cost of every Transpose node in the decomposed ONNX graph is
        # predicted with predict_transposition_cost and summed up.
        # loops over all permutations
        if self.equation.lower() != self.equation:
            raise RuntimeError(  # pragma: no cover
                "Only lower equation can be optimized, %r is not." % self.equation)
        letters = list(
            sorted(set(c for c in self.equation if "a" <= c <= "z")))
        possible = list(permutations(letters))
        possible.insert(0, letters)
        if self.verbose:
            from tqdm import tqdm  # pragma: no cover
            subset = tqdm(possible)  # pragma: no cover
        else:
            subset = possible
        best = []
        confs = []
        very_best = None
        inputs = None
        for perm in subset:
            replace = {d: c for c, d in zip(letters, perm)}
            eq = self.equation
            for k, v in replace.items():
                eq = eq.replace(k, v.upper())
            eq = eq.lower()
            inst = CachedEinsum(eq, runtime=self.runtime, opset=self.opset,
                                optimize=False, dtype=self.dtype,
                                decompose=self.decompose)
            inst.build()
            if inputs is None:
                inputs = inst.default_inputs()
            if hasattr(inst, 'onnx_'):
                onx = inst.onnx_
            else:
                from skl2onnx.common.data_types import FloatTensorType  # delayed
                inits = [
                    ('X%d' % i, FloatTensorType(list(inputs[i].shape)))
                    for i in range(len(inputs))]
                onx = inst.graph_.to_onnx('Y', *inits, opset=self.opset)

            # run the graph once with the micro-runtime to obtain the
            # intermediate shapes feeding every Transpose node
            rt = OnnxMicroRuntime(onx)
            dict_inputs = {'X%d' % i: inp for i, inp in enumerate(inputs)}
            out = rt.run(dict_inputs)

            transposes = []
            for node in onx.graph.node:  # pylint: disable=E1101
                if node.op_type == 'Transpose':
                    # dimensions > 1 are scaled to weigh real transpositions
                    shape = [(d * 10 if d > 1 else d)
                             for d in out[node.input[0]].shape]
                    transposes.append(
                        [shape, list(node.attribute[0].ints)])

            delta = sum(max(0, predict_transposition_cost(*v))
                        for v in transposes)

            confs.append((delta, eq))
            if len(best) < 10:
                best.append((delta, eq))
                best.sort()
            elif delta < best[-1][0]:
                best[-1] = (delta, eq)
                best.sort()
            if self.verbose and (
                    very_best is None or very_best != best[0][0]):
                very_best = best[0][0]
                subset.set_description("%1.2g mlbest=%r" % best[0])
        self.optimized_ = best
        self.timed_permutations_ = confs
        return best[0][1]

    def build_onnx_einsum(self, input_names):
        """
        Builds an ONNX graph with a single einsum operator.
        """
        opset = (self.opset if self.opset is not None
                 else __max_supported_opset__)
        ir_version = get_ir_version(opset)
        proto_type = guess_proto_dtype(
            numpy.float32 if self.dtype is None else self.dtype)

        model = helper.make_model(
            opset_imports=[helper.make_operatorsetid('', opset)],
            ir_version=ir_version,
            producer_name='mlprodict',
            producer_version='0.0.1',
            graph=helper.make_graph(
                name='einsum',
                inputs=[helper.make_tensor_value_info(n, proto_type, None)
                        for n in input_names],
                outputs=[helper.make_tensor_value_info("Y", proto_type, None)],
                nodes=[
                    helper.make_node(
                        'Einsum', input_names, ["Y"], equation=self.equation_)]))
        return model

    def build_runtime(self):
        """
        Builds the runtime associated to the
        equation `self.equation_`.

        Sets `runtime_` and, depending on the runtime, `graph_`,
        `onnx_`, `onnx_names_`, `oinf_`.
        """
        if self.decompose:
            # decompose Einsum into simpler operators first
            self.graph_ = decompose_einsum_equation(
                self.equation_, strategy='numpy', clean=True)
            if self.runtime == 'batch_dot':
                self.runtime_ = lambda *inputs: apply_einsum_sequence(
                    self.graph_, *inputs)
            elif self.runtime in ('python', 'onnxruntime1'):
                from ...onnxrt import OnnxInference
                n_inputs = len(self.graph_.metadata['lengths']) - 1
                input_names = ['X%d' % i for i in range(n_inputs)]
                self.onnx_names_ = input_names
                onx = self.graph_.to_onnx(
                    'Y', *input_names, opset=self.opset, dtype=self.dtype)
                self.onnx_ = onx
                rt = ('python_compiled'
                      if self.runtime == 'python'
                      else self.runtime)
                self.oinf_ = OnnxInference(
                    self.onnx_, runtime=rt, runtime_options=dict(
                        log_severity_level=3))
                self.runtime_ = lambda *inputs: self.oinf_.run(
                    {i: v for i, v in zip(self.onnx_names_, inputs)})['Y']
            else:
                raise ValueError(  # pragma: no cover
                    "Unexpected runtime %r." % self.runtime)
        else:
            # keep a single ONNX Einsum node
            if self.runtime in ('python', 'onnxruntime1'):
                from ...onnxrt import OnnxInference
                n_inputs = len(self.equation.split('->')[0].split(','))
                input_names = ['X%d' % i for i in range(n_inputs)]
                self.onnx_ = self.build_onnx_einsum(input_names)
                self.onnx_names_ = input_names
                rt = ('python_compiled'
                      if self.runtime == 'python'
                      else self.runtime)
                self.oinf_ = OnnxInference(
                    self.onnx_, runtime=rt, runtime_options=dict(
                        log_severity_level=3))
                self.runtime_ = lambda *inputs: self.oinf_.run(
                    {i: v for i, v in zip(self.onnx_names_, inputs)})['Y']
            else:
                raise ValueError(  # pragma: no cover
                    "Unexpected runtime %r." % self.runtime)

    def __call__(self, *inputs):
        """
        Calls the runtime `self.runtime_`.
        """
        if not hasattr(self, 'runtime_'):
            raise RuntimeError(  # pragma: no cover
                "Method build_runtime was not called.")
        return self.runtime_(*inputs)

    @staticmethod
    def build_einsum(equation, runtime, opset, optimize,
                     dtype, decompose=True, strategy=None,
                     verbose=None, key=None):
        """
        Creates an instance of *CachedEinsum*.
        """
        inst = CachedEinsum(equation, runtime=runtime, opset=opset,
                            optimize=optimize, dtype=dtype,
                            decompose=decompose, strategy=strategy,
                            verbose=verbose, key=key)
        inst.build()
        return inst
def _einsum(equation, dtype, optimize=False, runtime="batch_dot",
            cache=True, opset=None, decompose=True, strategy=None,
            verbose=None):
    """
    Returns a built :class:`CachedEinsum` for the given configuration,
    reusing and populating the module-level cache when *cache* is True.
    """
    global _einsum_cache  # pylint: disable=W0603,W0602
    key = ((equation, runtime, opset, optimize, dtype, decompose, strategy)
           if cache else None)
    if key is not None:
        hit = _einsum_cache.get(key, None)
        if hit is not None:
            # cache hit: reuse the preprocessed instance as-is
            return hit
    built = CachedEinsum.build_einsum(
        equation, runtime, opset, optimize,
        dtype, decompose=decompose, strategy=strategy,
        verbose=verbose, key=key)
    if key is not None:
        _einsum_cache[key] = built
    return built
def optimize_decompose_einsum_equation(
        equation, dtype, optimize=False, runtime="batch_dot",
        cache=True, opset=None, decompose=True, strategy=None,
        verbose=None):
    """
    Proposes a new implementation of :epkg:`numpy:einsum`.
    It does not allow expression using `...` and expects
    a right member.

    :param equation: einsum equation
    :param optimize: permutes all letters to find the best
        permutation
    :param runtime: runtime used to compute the results once the
        computation graph is produced (see below)
    :param cache: if True, the function stores the preprocessing
        done for a specific equation, the second call with the same
        equation is much faster
    :param opset: ONNX opset to use for some runtimes
    :param decompose: by default, the function decomposes
        the equation into more simple operators but it can keep
        the original ONNX einsum operator.
    :param strategy: optimisation strategy (see below)
    :param verbose: display progress if optimize is True
    :return: einsum result

    The available runtimes are:

    * `batch_dot`: the runtime is @see fn apply_einsum_sequence,
    * `python`: one ONNX graph executed with a python runtime,
    * `onnxruntime1`: one ONNX graph executed with :epkg:`onnxruntime`.

    The optimisation strategy can be:

    * `None`: the same runtime is used to find the best permutation of letters
    * `'ml'`: a machine learned model is used to predict the
      best permutation of letters, this model comes from
      notebook :ref:`onnxoperatorcostrst`.

    The function works in two steps:

    * first step analyses the equation to produce a computation graph,
      this graph can also be converted into ONNX,
    * second step runs the graph whatever the graph is.

    The function returns an object of type @see cl CachedEinsum
    which has the following members after optimization:

    * `equation_` corresponding to the best equivalent equation
    * `graph_`: the corresponding graph returned by function
        :func:`decompose_einsum_equation
        <mlprodict.testing.einsum.einsum_impl.decompose_einsum_equation> `
    * `onnx_`: if a conversion to onnx is used, stores the onnx graph
    * `runtime_`: a function used by `__call__`, calls the runtime
    * `oinf_`: an object of type @see cl OnnxInference
    * `timed_permutations_`: memorizes the results of the optimization

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import optimize_decompose_einsum_equation

        seq_opt = optimize_decompose_einsum_equation(
            "bsnh,btnh->bnts", numpy.float64, strategy='ml', verbose=1,
            runtime="python", optimize=True)

        print("best equation:", seq_opt.equation_)
    """
    # delegates to the cached builder; the result is a CachedEinsum instance
    res = _einsum(equation, dtype, optimize=optimize, runtime=runtime,
                  cache=cache, opset=opset, decompose=decompose,
                  strategy=strategy, verbose=verbose)
    return res
def einsum(equation, *inputs, optimize=False, runtime="batch_dot",
           cache=True, opset=None, decompose=True,
           strategy=None, verbose=None):
    """
    Proposes a new implementation of :epkg:`numpy:einsum`.
    It does not allow expression using `...` and expects
    a right member.

    :param equation: einsum equation
    :param inputs: inputs
    :param optimize: permutes all letters to find the best
        permutation
    :param runtime: runtime used to compute the results once the
        computation graph is produced (see below)
    :param cache: if True, the function stores the preprocessing
        done for a specific equation, the second call with the same
        equation is much faster
    :param opset: ONNX opset to use for some runtimes
    :param decompose: by default, the function decomposes
        the equation into more simple operators but it can keep
        the original ONNX einsum operator.
    :param strategy: optimisation strategy (see below)
    :param verbose: display progress if optimize is True
    :return: einsum result

    The available runtimes are:

    * `batch_dot`: the runtime is @see fn apply_einsum_sequence,
    * `python`: one ONNX graph executed with a python runtime,
    * `onnxruntime1`: one ONNX graph executed with :epkg:`onnxruntime`.

    The optimisation strategy can be:

    * `None`: the same runtime is used to find the best permutation of letters
    * `'ml'`: a machine learned model is used to predict the
      best permutation of letters, this model comes from
      notebook :ref:`onnxoperatorcostrst`.

    The function works in two steps:

    * first step analyses the equation to produce a computation graph,
      this graph can also be converted into ONNX,
    * second step runs the graph whatever the graph is.

    Further details are available in the documentation of function
    @see fn optimize_decompose_einsum_equation.
    The function works the same way as :epkg:`numpy:einsum`:

    .. runpython::
        :showcode:

        import numpy
        from mlprodict.testing.einsum import einsum

        equation = "abc,cd->abd"

        m1 = numpy.random.randn(2, 2, 2)
        m2 = numpy.random.randn(2, 2)

        np = numpy.einsum(equation, m1, m2)
        print('numpy.einsum')
        print(np)

        print('mlprodict.testing.einsum')
        mp = einsum(equation, m1, m2)
        print(mp)

    In some case, the einsum implementation can be optimized by looping
    on possible permutation:

    .. runpython::
        :showcode:
        :process:

        import timeit
        import numpy
        from mlprodict.testing.einsum import einsum
        from mlprodict.testing.einsum.einsum_fct import enumerate_cached_einsum

        equation = "cab,cd->ad"

        m1 = numpy.random.randn(20, 20, 20)
        m2 = numpy.random.randn(20, 20)

        print('numpy.einsum',
              timeit.timeit('numpy.einsum(equation, m1, m2)',
                            number=200,
                            globals=globals()))

        einsum(equation, m1, m2)
        print('einsum',
              timeit.timeit('einsum(equation, m1, m2)',
                            number=200,
                            globals=globals()))

        einsum(equation, m1, m2, runtime='python')
        print('einsum-python',
              timeit.timeit('einsum(equation, m1, m2, runtime="python")',
                            number=200,
                            globals=globals()))

        einsum(equation, m1, m2, runtime='onnxruntime1')
        print('einsum-onnxruntime1',
              timeit.timeit('einsum(equation, m1, m2, runtime="onnxruntime1")',
                            number=200,
                            globals=globals()))

        einsum(equation, m1, m2, runtime='onnxruntime1', optimize=True, verbose=1)
        print('einsum-onnxruntime1',
              timeit.timeit('einsum(equation, m1, m2, runtime="onnxruntime1", optimize=True)',
                            number=200,
                            globals=globals()))

        print("list of cached einsum equations")
        for k, v in enumerate_cached_einsum():
            print(k, v.equation, v.equation_)

    The last example shows the time taken by every function:

    .. runpython::
        :showcode:
        :process:

        import os
        from pyquickhelper.pycode.profiling import profile
        import numpy
        from mlprodict.testing.einsum import einsum
        from mlprodict.testing.einsum.einsum_fct import enumerate_cached_einsum
        from mlprodict import __file__ as path

        root = os.path.dirname(path)

        equation = "cab,cd->ad"

        m1 = numpy.random.randn(200, 20, 20)
        m2 = numpy.random.randn(200, 20)

        def clean(txt):
            txt = txt.replace(root, "mlprodict")
            return "\\n".join(txt.split("\\n")[:30])

        def fct1():
            for i in range(100):
                einsum(equation, m1, m2, cache=False)

        print("Profile cache with default runtime.")
        res = profile(fct1)
        print(root)
        print(clean(res[1]))

        def fct2():
            for i in range(100):
                einsum(equation, m1, m2, cache=False, runtime='python')

        print("Profile cache with runtime='python'.")
        res = profile(fct2)
        print(root)
        print(clean(res[1]))

        def fct3():
            for i in range(100):
                einsum(equation, m1, m2, cache=True)

        einsum(equation, m1, m2, cache=True)
        print("Profile execution with default runtime.")
        res = profile(fct3)
        print(root)
        print(clean(res[1]))

        def fct4():
            for i in range(100):
                einsum(equation, m1, m2, cache=True, runtime='python')

        einsum(equation, m1, m2, cache=True, runtime='python')
        print("Profile execution with runtime='python'.")
        res = profile(fct4)
        print(root)
        print(clean(res[1]))

        def fct5():
            for i in range(100):
                einsum(equation, m1, m2, cache=True, runtime='onnxruntime1')

        einsum(equation, m1, m2, cache=True, runtime='onnxruntime1')
        print("Profile execution with runtime='onnxruntime1'.")
        res = profile(fct5)
        print(root)
        print(clean(res[1]))
    """
    if len(inputs) == 0:
        raise ValueError("No inputs found.")  # pragma: no cover
    # all inputs must share one dtype: it becomes part of the cache key
    dtypes = set(i.dtype for i in inputs)
    if len(dtypes) != 1:
        raise ValueError(  # pragma: no cover
            "All inputs do not have the same type (%r), "
            "all of them should be cast before called einsum."
            "" % dtypes)
    cached = optimize_decompose_einsum_equation(
        equation, inputs[0].dtype, optimize=optimize,
        runtime=runtime, cache=cache, opset=opset,
        decompose=decompose, strategy=strategy, verbose=verbose)
    return cached(*inputs)