Coverage for mlprodict/cli/einsum.py: 93%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Command line to check einsum scenarios.
4"""
5import os
def einsum_test(equation="abc,cd->abd", shape="30", perm=False,
                runtime='python', verbose=1, fLOG=print,
                output=None, number=5, repeat=5):
    """
    Investigates whether or not the decomposing einsum is faster.

    :param equation: einsum equation to test
    :param shape: an integer (all dimensions get the same size),
        a string holding a list of parenthesized shapes separated
        with `;` (a `,` between two parenthesized groups is accepted
        too), or a string holding a comma-separated list of integers
        to try out multiple shapes, examples: `5`, `(5,5,5);(5,5)`,
        `(5,5,5),(5,5)`, `5,6`; a value which is not a string
        (int, list, tuple) is forwarded unchanged
    :param perm: check on permutation or all letter permutations
    :param runtime: `'numpy'`, `'python'`, `'onnxruntime'`
    :param verbose: verbose
    :param fLOG: logging function
    :param output: output file (usually a csv file or an excel file),
        it requires pandas
    :param number: usual parameter to measure a function
    :param repeat: usual parameter to measure a function
    :raises ValueError: if *output* does not end with `.csv` or `.xlsx`
        (checked case-insensitively, before the benchmark runs)

    .. cmdref::
        :title: Investigates whether or not the decomposing einsum is faster.
        :cmd: -m mlprodict einsum_test --help
        :lid: l-cmd-einsum_test

        The command checks whether or not decomposing an einsum function
        is faster than einsum implementation.

        Example::

            python -m mlprodict einsum_test --equation="abc,cd->abd" --output=res.csv
    """
    import re

    # Command line arguments arrive as strings, Python callers may pass bools.
    perm = perm in ('True', 'true', '1', 1, True)

    # Only strings need parsing; ints/lists/tuples are already usable.
    if isinstance(shape, str):
        if "(" not in shape:
            if "," in shape:
                shape = list(map(int, shape.split(",")))
            else:
                shape = int(shape)
        else:
            # "(5,5,5),(5,5)" must mean two shapes, like "(5,5,5);(5,5)":
            # turn a comma between two parenthesized groups into ';'
            # before the parentheses are stripped.
            normalized = re.sub(r"\)\s*,\s*\(", ");(", shape)
            normalized = normalized.replace('(', '').replace(')', '')
            parsed = []
            for sh in normalized.split(";"):
                sh = sh.strip()
                if not sh:
                    # tolerate a trailing or doubled ';'
                    continue
                parsed.append(tuple(map(int, sh.split(','))))
            shape = parsed

    verbose = int(verbose)
    number = int(number)
    repeat = int(repeat)

    # Validate the output file before running the (possibly long) benchmark.
    ext = None
    if output not in ('', None):
        raw_ext = os.path.splitext(output)[-1]
        ext = raw_ext.lower()
        if ext not in ('.csv', '.xlsx'):
            raise ValueError(  # pragma: no cover
                "Unknown extension %r in file %r." % (raw_ext, output))

    from ..testing.einsum.einsum_bench import einsum_benchmark  # pylint: disable=E0402

    res = einsum_benchmark(equation=equation, shape=shape, perm=perm,
                           runtime=runtime, use_tqdm=verbose > 0,
                           number=number, repeat=repeat)

    if ext is None:
        # No output file: log every benchmark record.
        for r in res:
            fLOG(r)
    else:
        import pandas
        df = pandas.DataFrame(res)
        if ext == '.csv':
            df.to_csv(output, index=False)
        else:
            df.to_excel(output, index=False)
        fLOG('[einsum_test] wrote file %r.' % output)