Coverage for mlprodict/onnxrt/validate/validate_benchmark.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Measures time processing for ONNX models.
4"""
5import numpy
6from sklearn import __all__ as sklearn__all__, __version__ as sklearn_version
7from ... import __version__ as ort_version
8from .validate_helper import default_time_kwargs, measure_time
def make_n_rows(x, n, y=None):
    """
    Multiplies or reduces the rows of x to get
    exactly *n* rows.

    When *n* is smaller than the number of rows of *x*, the first *n*
    rows are kept. Otherwise the rows of *x* are repeated cyclically:
    row ``i`` of the result is row ``i % x.shape[0]`` of *x*.
    Only the first axis changes, so arrays of any rank are supported.
    The same row selection is applied to *y*.

    @param      x       matrix (array with at least one dimension)
    @param      n       number of rows
    @param      y       target (optional), expected to have as many
                        rows as *x*
    @return             new matrix or two new matrices if y is not None
    @raises     ValueError if *x* has no rows but *n* is positive
    """
    n_rows = x.shape[0]
    if n < n_rows:
        # Truncation: plain slicing, copied so the caller's data is untouched.
        if y is None:
            return x[:n].copy()
        return x[:n].copy(), y[:n].copy()
    if n_rows == 0 and n > 0:
        raise ValueError(
            "Unable to build {} rows from an empty array x.shape={}.".format(
                n, x.shape))
    # Cyclic row selection; integer-array indexing always returns a copy
    # and keeps the dtype and the trailing dimensions unchanged.
    # max(..., 1) only matters when n == n_rows == 0 (empty index).
    index = numpy.arange(n) % max(n_rows, 1)
    r = numpy.asarray(x)[index]
    if y is None:
        return r
    # y follows the same cycle as x (period x.shape[0]).
    return r, numpy.asarray(y)[index]
def _is_size_allowed(N, obs):
    """
    Tells if *N* rows may be benchmarked given the observation *obs*
    (problems containing ``-cov`` are skipped above 1000 rows).
    """
    if obs is None:
        return True  # pragma: no cover
    if "-cov" in obs['problem'] and N > 1000:
        return False  # pragma: no cover
    return True


def _sum_node_times(fct, x, N, number):
    """
    Runs ``fct(x)`` *number* times and sums the ``time`` of every node
    returned as the second element of the result. The rows are copied
    (with ``N`` added) so the objects returned by *fct* are never modified.
    """
    agg = None
    for _ in range(number):
        ms = fct(x)[1]
        if agg is None:
            agg = [dict(row, N=N) for row in ms]
            continue
        if len(agg) != len(ms):
            raise RuntimeError(  # pragma: no cover
                "Not the same number of nodes {} != {}.".format(len(agg), len(ms)))
        for a, b in zip(agg, ms):
            a['time'] += b['time']
    return agg


def _benchmark_node_time(fct, x, N, repeat, number):
    """
    Measures the time spent in every node for one value of *N*.
    Returns one dictionary per node with ``time`` (average per run),
    ``max_time`` and ``min_time`` (extrema over the *repeat* repetitions
    of the per-run time), ``repeat``, ``number`` and ``N``.
    """
    if repeat < 1 or number < 1:
        raise ValueError(
            "repeat ({}) and number ({}) must be >= 1 when node_time "
            "is True.".format(repeat, number))
    fct(x)  # warm-up run, not measured
    main = None
    for _ in range(repeat):
        agg = _sum_node_times(fct, x, N, number)
        if main is None:
            main = agg
            # The first repetition counts for the extrema as well.
            for row in main:
                row['max_time'] = row['time']
                row['min_time'] = row['time']
            continue
        if len(agg) != len(main):
            raise RuntimeError(  # pragma: no cover
                "Not the same number of nodes {} != {}.".format(len(agg), len(main)))
        for a, b in zip(main, agg):
            a['time'] += b['time']
            a['max_time'] = max(a['max_time'], b['time'])
            a['min_time'] = min(a['min_time'], b['time'])
    for row in main:
        row['repeat'] = repeat
        row['number'] = number
        row['time'] /= repeat * number
        row['max_time'] /= number
        row['min_time'] /= number
    return main


def benchmark_fct(fct, X, time_limit=4, obs=None, node_time=False,
                  time_kwargs=None, skip_long_test=True):
    """
    Benchmarks a function which takes an array
    as an input and changes the number of rows.

    @param      fct             function to benchmark, signature
                                is `fct(xo)`
    @param      X               array
    @param      time_limit      above this time, measurement is stopped
    @param      obs             all information available in a dictionary
    @param      node_time       measure time execution for each node in the graph
    @param      time_kwargs     to define a more precise way to measure a model
    @param      skip_long_test  skips tests for high values of N if they seem too long
    @return                     dictionary with the results ``{N: measure}``,
                                or, if *node_time* is True, a flat list with
                                one dictionary per node and per *N*

    The function uses *obs* to reduce the number of tries it does.
    :epkg:`sklearn:gaussian_process:GaussianProcessRegressor`
    produces huge *NxN* if predict method is called
    with ``return_cov=True``.
    When *node_time* is True, ``fct(x)[1]`` must be a list of
    dictionaries, one per node, each with at least a ``time`` key.
    The default for *time_kwargs* is the following:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.validate.validate_helper import default_time_kwargs
        import pprint
        pprint.pprint(default_time_kwargs())

    See also notebook :ref:`onnxnodetimerst` to see how this function
    can be used to measure time spent in each node.
    """
    if time_kwargs is None:
        time_kwargs = default_time_kwargs()  # pragma: no cover

    res = {}
    for N in sorted(time_kwargs):
        if not isinstance(N, int):
            raise RuntimeError(  # pragma: no cover
                "time_kwargs ({}) is wrong:\n{}".format(
                    type(time_kwargs), time_kwargs))
        if not _is_size_allowed(N, obs):
            continue  # pragma: no cover
        x = make_n_rows(X, N)
        number = time_kwargs[N]['number']
        repeat = time_kwargs[N]['repeat']
        if node_time:
            res[N] = _benchmark_node_time(fct, x, N, repeat, number)
        else:
            res[N] = measure_time(fct, x, repeat=repeat,
                                  number=number, div_by_number=True)
        # Bigger N would only take longer: stop once the limit is reached.
        if (skip_long_test and not node_time and
                res[N] is not None and
                res[N].get('ttime', time_limit) >= time_limit):
            break  # pragma: no cover
    if node_time:
        return [row for rows in res.values() for row in rows]
    return res