Coverage for mlprodict/testing/einsum/einsum_ml.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Functions used to predict the cost of a transposition.
4"""
5import numpy
8_ml_transpose_coefs = {
9 'CST_': 0.4720163707200312,
10 'begin': 0.0,
11 'dbegin': 0.0,
12 'dend': 0.0,
13 'dim': 0.0,
14 'discont': 0.0180766756730043,
15 'edit': 0.06940318842803926,
16 'end': 0.0,
17 'end16': 0.0,
18 'end32': 0.0,
19 'ibegin16': 0.0,
20 'ibegin2': 0.0,
21 'ibegin32': 0.0,
22 'ibegin4': 0.0,
23 'ibegin64': 0.0,
24 'ibegin8': 0.04389296884016416,
25 'iend16': 0.5316238365817172,
26 'iend2': 0.16287259236456927,
27 'iend32': 0.0,
28 'iend4': 0.0,
29 'iend64': 0.0,
30 'iend8': 0.0,
31 'middle': 1.3381940773605624e-06,
32 'rbegin': 0.0,
33 'rdiscont': 0.0,
34 'redit': 0.18604684802855143,
35 'rend': 0.0,
36 'rend16': 0.0,
37 'rend32': 0.0,
38 'rev': 0.42909943168149206,
39 'rmiddle': 0.0,
40 'rot': 0.22272566615803094,
41 'size': 2.8663794075460607e-06}
44def _edit_distance(mot1, mot2):
45 dist = {(-1, -1): 0}
46 if len(mot1) == 0:
47 for j, d in enumerate(mot2):
48 dist[-1, j] = dist[-1, j - 1] + 1
49 dist[j, -1] = dist[j - 1, -1] + 1
50 for i, c in enumerate(mot1):
51 dist[i, -1] = dist[i - 1, -1] + 1
52 dist[-1, i] = dist[-1, i - 1] + 1
53 for j, d in enumerate(mot2):
54 opt = []
55 if (i - 1, j) in dist:
56 x = dist[i - 1, j] + 1
57 opt.append((x, (i - 1, j)))
58 if (i, j - 1) in dist:
59 x = dist[i, j - 1] + 1
60 opt.append((x, (i, j - 1)))
61 if (i - 1, j - 1) in dist:
62 x = dist[i - 1, j - 1] + (1 if c != d else 0)
63 opt.append((x, (i - 1, j - 1)))
64 mi = min(opt)
65 dist[i, j] = mi[0]
67 return dist[len(mot1) - 1, len(mot2) - 1]
70def _is_rotation(perm):
71 t = tuple(perm)
72 c = list(range(len(perm)))
73 for i in range(len(c)):
74 for k in range(len(c)): # pylint: disable=C0200
75 c[k] = (k + i) % len(c)
76 if t == tuple(c):
77 return True
78 return False
81def _relu(x, origin=0):
82 return origin if x < origin else x
def compute_transposition_features(shape, perm):
    """
    Given a shape and a permutation, computes many features
    used to predict the cost of the transposition.

    :param shape: shape
    :param perm: permutation (tuple, list or any indexable sequence
        of the same length as *shape*)
    :return: dictionary of features

    .. runpython::
        :showcode:

        import pprint
        from mlprodict.testing.einsum.einsum_ml import (
            compute_transposition_features)

        pprint.pprint(
            compute_transposition_features((3, 5, 7), (2, 1, 0)))
    """
    total = numpy.prod(numpy.array(shape, dtype=numpy.int64))

    # Leading axes the permutation leaves in place: their count (dbegin)
    # and the product of their sizes (begin).
    begin = 1
    dbegin = 0
    for i, p in enumerate(perm):
        if p != i:
            break
        dbegin += 1
        begin *= shape[i]

    # Trailing axes the permutation leaves in place: their count (dend)
    # and the product of their sizes (end).
    end = 1
    dend = 0
    for i in range(len(perm) - 1, -1, -1):
        if perm[i] != i:
            break
        dend += 1
        end *= shape[i]

    # Number of positions where two consecutive output axes are not
    # consecutive input axes.
    dis_cont = 0
    for i in range(1, len(shape)):
        if perm[i] != perm[i - 1] + 1:
            dis_cont += 1

    # NOTE(review): a zero-sized axis among the fixed leading/trailing
    # axes makes end * begin == 0 here, and total == 0 makes the ratio
    # features below divide by zero; such shapes are not handled.
    middle = max(1, int(total / (end * begin)))
    feat = dict(size=total, begin=begin, end=end, middle=middle,
                dim=len(shape), discont=dis_cont)

    for c in [16, 32]:
        feat["end%d" % c] = _relu(end, c)

    # Ratio of every feature computed so far to the total number of
    # elements (prefixed with 'r'), except the dimension and the size.
    keys = list(feat)
    for k in keys:
        if k in {'dim', 'cpu', 'size'}:
            continue
        feat['r%s' % k] = float(feat[k] / total)

    # Threshold indicators on the contiguous trailing/leading blocks.
    for c in [2, 4, 8, 16, 32, 64]:
        feat["iend%d" % c] = float(end >= c)
        feat["ibegin%d" % c] = float(begin >= c)

    # Constant feature: multiplied by its coefficient it acts as the
    # intercept of the linear model in predict_transposition_cost.
    feat['CST_'] = -1
    feat['dbegin'] = - dbegin
    feat['dend'] = - dend

    # Negate the contiguity-related features. Presumably this keeps all
    # trained coefficients non-negative (the default ones are) —
    # TODO confirm with the training code.
    negated_prefixes = ('end', 'begin', 'rend', 'rbegin', 'iend', 'ibegin')
    keys = list(feat)
    for k in keys:
        if k.startswith(negated_prefixes) or k == "rdiscont":
            feat[k] = - feat[k]

    idp = list(range(len(perm)))
    feat["rot"] = -1 if _is_rotation(perm) else 0
    # Compare as tuples: a list never equals a tuple in Python, and a
    # numpy array would produce an ambiguous element-wise result.
    feat["rev"] = 1 if tuple(perm) == tuple(idp[::-1]) else 0
    feat["edit"] = _edit_distance(idp, perm)
    # An empty permutation has no edits; avoid dividing by zero.
    feat["redit"] = feat["edit"] / len(idp) if idp else 0.
    return feat
def predict_transposition_cost(shape, perm, coefs=None):
    """
    Given a shape and a permutation, predicts the cost of the
    transposition with a linear model: the weighted sum of the
    features returned by :func:`compute_transposition_features`,
    divided by 1000 and clipped at zero.

    :param shape: shape
    :param perm: permutation
    :param coefs: trained coefficients, a dictionary mapping every
        feature name to its weight, or None to get the default ones
    :return: predicted cost, a non-negative float
    :raises KeyError: if *coefs* misses one of the feature names

    .. runpython::
        :showcode:

        from mlprodict.testing.einsum.einsum_ml import (
            predict_transposition_cost)

        print(predict_transposition_cost((3, 5, 7), (2, 1, 0)))
    """
    if coefs is None:
        coefs = _ml_transpose_coefs
    feat = compute_transposition_features(shape, perm)
    # Explicit left-to-right accumulation; kept as a loop so the float
    # result does not depend on the summation algorithm of sum().
    res = 0
    for k, v in feat.items():
        res += v * coefs[k]
    # A negative prediction is meaningless for a cost: clip it at zero.
    return max(0., res / 1000)