Coverage for mlprodict/onnxrt/validate/side_by_side.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Helpers to compare executions.
4"""
5import copy
6import numpy
7from .validate_difference import measure_relative_difference
def _side_by_side_by_values_inputs(sess, inputs, i):
    """
    Resolves the session and the inputs to use for the *i*-th run.

    :param sess: a session, or a tuple *(session, inputs)*
        when *inputs* is None
    :param inputs: None, a list holding one set of inputs per session,
        or a single set of inputs shared by every session
    :param i: index of the session in the list of sessions
    :return: tuple *(session, inputs)*
    """
    # The session carries its own inputs.
    if inputs is None and isinstance(sess, tuple):
        paired_sess, paired_inputs = sess
        return paired_sess, paired_inputs

    # One set of inputs per session: pick the one matching this session.
    if isinstance(inputs, list):
        return sess, inputs[i]

    # Shared inputs: every session receives its own deep copy.
    return sess, copy.deepcopy(inputs)
def _difference_label(diff):
    """
    Converts the largest discrepancy measured for one intermediate
    result into the label stored in column *cmp*.

    :param diff: largest difference with the baseline
    :return: ``'OK'``, ``'e<...'`` or ``'ERROR->=...'``
    """
    thresholds = ((1e-5, 'OK'), (0.0001, 'e<0.0001'), (0.001, 'e<0.001'),
                  (0.01, 'e<0.01'), (0.1, 'e<0.1'))
    for bound, label in thresholds:
        if diff < bound:
            return label
    # Also reached when *diff* is NaN since every comparison is False.
    return "ERROR->=%1.1f" % diff  # pragma: no cover


def side_by_side_by_values(sessions, *args, inputs=None,
                           return_results=False, **kwargs):
    """
    Compares the execution of several sessions.
    It calls method :meth:`OnnxInference.run
    <mlprodict.onnxrt.onnx_inference.OnnxInference.run>`
    with value ``intermediate=True`` and compares the results.

    :param sessions: list of class @see cl OnnxInference
    :param inputs: inputs
    :param args: additional parameters for
        :meth:`OnnxInference.run
        <mlprodict.onnxrt.onnx_inference.OnnxInference.run>`
    :param return_results: if True, returns the results as well.
    :param kwargs: additional parameters for
        :meth:`OnnxInference.run
        <mlprodict.onnxrt.onnx_inference.OnnxInference.run>`,
        *verbose* and *fLOG* are also read to log the progress
    :return: list of dictionaries, or a tuple
        *(list of dictionaries, results)* if *return_results* is True
    :raises ValueError: if *intermediate* is set to False,
        if a python session has the inplace mechanism enabled,
        or if *sessions* is empty

    The first session is considered as the baseline.
    See notebook :ref:`onnxsbsrst` for an example.
    If *inputs* is None, the function assumes
    *sessions* is a list of *tuple(sessions, inputs)*
    because sometimes inputs must be different.

    .. versionchanged:: 0.7
        Parameter *return_results* was added. The function
        returns the execution order when available.
    """
    if not kwargs.get('intermediate', True):
        raise ValueError(  # pragma: no cover
            "kwargs must not set intermediate to False")
    kwargs['intermediate'] = True
    verbose = kwargs.get('verbose', 0)
    fLOG = kwargs.get('fLOG', None)

    # run
    results = []
    orders = []
    for i, sess in enumerate(sessions):
        if (hasattr(sess, 'runtime') and hasattr(sess, 'inplace') and
                sess.runtime in (None, 'python') and sess.inplace):
            raise ValueError(
                "You must disable the inplace mechanism in order to get "
                "true results. See OnnxInference constructor.")
        new_sess, new_inputs = _side_by_side_by_values_inputs(sess, inputs, i)
        if verbose > 0 and fLOG:
            fLOG(  # pragma: no cover
                '[side_by_side_by_values] run session {}/{}'.format(
                    i + 1, len(sessions)))
        res = new_sess.run(new_inputs, *args, **kwargs)
        order = new_sess.get_execution_order()
        results.append(list(res.items()))
        orders.append(order)

    if not results:
        # Nothing was run, there is no baseline to compare to.
        raise ValueError("sessions must contain at least one session.")

    # same number of results?
    sizes = [len(res) for res in results]
    row = {"metric": "nb_results", 'step': -1}
    for i, size in enumerate(sizes):
        row["v[%d]" % i] = size
    row['cmp'] = 'OK' if min(sizes) == max(sizes) else '!='
    rows = [row]

    merged = merge_results(results)

    # analysis: two rows (relative and absolute difference)
    # for every intermediate result
    for step, (name, res_row) in enumerate(merged):
        for metric in ('rel-diff', 'abs-diff'):
            row = {'step': step, 'name': name, 'metric': metric}

            vals = []
            for j, r in enumerate(res_row):
                order = orders[j]
                if order is not None:
                    row['order[%d]' % j] = order.get(
                        ('res', name), (numpy.nan, ))[0]
                row['value[%d]' % j] = r
                if hasattr(r, 'shape'):
                    row['shape[%d]' % j] = r.shape

                if j == 0:
                    # the baseline is compared to itself
                    row['v[%d]' % j] = 0
                elif res_row[0] is not None and r is not None:
                    v = measure_relative_difference(
                        res_row[0], r, abs_diff=metric == 'abs-diff')
                    row['v[%d]' % j] = v
                    vals.append(v)

            if vals:
                row['cmp'] = _difference_label(max(vals))

            rows.append(row)

    if return_results:
        return rows, results
    return rows
def merge_results(results):
    """
    Aligns the results of several sessions by name.
    The names of the first session define the initial order,
    names only found in the other sessions are inserted afterwards.

    :param results: results of intermediate variables,
        one list of *(name, value)* per session
    :return: list of tuple *(name, values)*, *values* holds one
        value per session, None if the session did not produce
        this name
    """
    n_sessions = len(results)

    # One line per name of the baseline, one column per session.
    # Every cell is None or a tuple (value, rank in the session).
    merged = [(name, []) for name, _ in results[0]]
    index_of = {entry[0]: idx for idx, entry in enumerate(merged)}

    pending = []
    for outputs in results:
        for _, column in merged:
            column.append(None)
        unmatched = []
        for rank, (name, value) in enumerate(outputs):
            target = index_of.get(name)
            if target is None:
                unmatched.append((rank, name, value))
            else:
                merged[target][1][-1] = (value, rank)
        pending.append(unmatched)

    # Names unknown to the baseline: each one is placed right before
    # the line holding the result which follows it in its own session,
    # or at the end if there is no such line.
    for col, unmatched in enumerate(pending):
        for rank, name, value in unmatched:
            follower = rank + 1
            anchor = next(
                (idx for idx, (_, line) in enumerate(merged)
                 if line[col] is not None and line[col][1] == follower),
                -1)
            cells = [None] * n_sessions
            if anchor == -1:
                cells[col] = (value, len(merged))
                merged.append((name, cells))
            else:
                cells[col] = (value, rank)
                merged.insert(anchor, (name, cells))

    # Drops the ranks, only the values are returned.
    return [(name, [cell if cell is None else cell[0] for cell in cells])
            for name, cells in merged]