Coverage for mlprodict/plotting/plotting_validate_graph.py: 97%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Functions to help visualizing performances.
4"""
5import numpy
6import pandas
9def _model_name(name):
10 """
11 Extracts the main component of a model, removes
12 suffixes such ``Classifier``, ``Regressor``, ``CV``.
14 @param name string
15 @return shorter string
16 """
17 if name.startswith("Select"):
18 return "Select"
19 if name.startswith("Nu"):
20 return "Nu"
21 modif = 1
22 while modif > 0:
23 modif = 0
24 for suf in ['Classifier', 'Regressor', 'CV', 'IC',
25 'Transformer']:
26 if name.endswith(suf):
27 name = name[:-len(suf)]
28 modif += 1
29 return name
def _benchmark_table(df):
    """
    Reshapes the output of @see fn summary_report into a wide table with
    one row per model configuration and, for every runtime, one column
    per benchmark ratio.

    @param      df          output of function @see fn summary_report
                            (left unmodified)
    @return                 tuple *(final, values, runtimes, origin)*:
                            *final* holds a ``label`` column followed by
                            columns named ``<ratio>__<runtime>``,
                            *values* is the list of ratio columns kept,
                            *runtimes* the sorted list of runtimes and
                            *origin* maps every column of *final* except
                            ``label`` to *(ratio, runtime)*
    @raises RuntimeError    if the benchmark columns are missing or empty
    """
    # Work on a copy so the caller's DataFrame does not gain the helper
    # columns (label, n_features, runtime) as a side effect.
    df = df.copy()
    if 'n_features' not in df.columns:
        df["n_features"] = numpy.nan  # pragma: no cover
    if 'runtime' not in df.columns:
        df['runtime'] = '?'  # pragma: no cover

    # NOTE(review): for the default scenario this yields labels such as
    # "name [problem-**]optim] D10" (two closing brackets); "-**|" may have
    # been intended -- confirm before changing, labels are user-visible.
    fmt = "{} [{}-{}|{}] D{}"
    df["label"] = df.apply(
        lambda row: fmt.format(
            row["name"], row["problem"], row["scenario"],
            row['optim'], row["n_features"]).replace("-default|", "-**]"), axis=1)
    df = df.sort_values(["name", "problem", "scenario", "optim",
                         "n_features", "runtime"],
                        ascending=False).reset_index(drop=True)

    indices = ['label', 'runtime']
    # Only the ratio columns are plotted, not their -min/-max variants.
    values = [c for c in df.columns
              if 'N=' in c and '-min' not in c and '-max' not in c]
    try:
        df = df[indices + values]
    except KeyError as e:  # pragma: no cover
        raise RuntimeError(
            "Unable to find the following columns {}\nin {}".format(
                indices + values, df.columns)) from e

    if 'RT/SKL-N=1' not in df.columns:
        raise RuntimeError(  # pragma: no cover
            "Column 'RT/SKL-N=1' is missing, benchmark was probably not run.")
    dfp = df[~df["RT/SKL-N=1"].isnull()]
    runtimes = sorted(set(dfp['runtime']))
    if not runtimes:
        # Without this check the merge loop leaves final=None and the
        # caller fails later with an unrelated AttributeError.
        raise RuntimeError(
            "Column 'RT/SKL-N=1' is empty, benchmark was probably not run.")

    # One frame per runtime with ratio columns suffixed by the runtime
    # name, joined on the label so each configuration gets a single row.
    final = None
    origin = {}
    for rt in runtimes:
        sub = dfp[dfp.runtime == rt].drop('runtime', axis=1)
        renamed = {}
        for col in sub.columns[1:]:
            renamed[col] = col + "__" + rt
            # Explicit mapping instead of re-splitting the name on '__',
            # which breaks if a runtime or ratio name contains '__'.
            origin[renamed[col]] = (col, rt)
        sub = sub.rename(columns=renamed)
        if final is None:
            final = sub
        else:
            final = final.merge(sub, on='label', how='outer')
    return final, values, runtimes, origin


def _add_summary_rows(final, runtimes):
    """
    Surrounds the benchmark rows with per-runtime statistics.

    On top: one ``med_<runtime>`` row per runtime (column medians), then a
    ``------`` row. At the bottom: a ``------`` row, then one
    ``avg_<runtime>`` row per runtime (column means). The benchmark rows
    in between are sorted by decreasing label.

    @param      final       table returned by @see fn _benchmark_table
    @param      runtimes    sorted list of runtimes
    @return                 new DataFrame with a fresh integer index
    """
    nrt = len(runtimes)
    # Columns are laid out as label, then one group of ncol per runtime.
    ncol = (final.shape[1] - 1) // nrt
    # Both legend frames need nrt + 1 rows; they are templated on the first
    # rows of final (repeating the first row when final is shorter) and
    # every cell is overwritten below.
    template = final.iloc[:nrt + 1, :].copy()
    while template.shape[0] < nrt + 1:
        template = pandas.concat([template, template[:1]])
    dfp_legend = template.copy()
    rleg = template.copy()
    dfp_legend.iloc[:, 1:] = numpy.nan
    rleg.iloc[:, 1:] = numpy.nan

    for r, runt in enumerate(runtimes):
        sli = slice(1 + ncol * r, 1 + ncol * (r + 1))
        group = final.iloc[:, sli]
        # .values avoids any index alignment when assigning a row slice.
        dfp_legend.iloc[r + 1, sli] = group.mean().values
        rleg.iloc[r, sli] = group.median().values
        dfp_legend.iloc[r + 1, 0] = "avg_" + runt
        rleg.iloc[r, 0] = "med_" + runt
    dfp_legend.iloc[0, 0] = "------"
    rleg.iloc[-1, 0] = "------"

    final = final.sort_values('label', ascending=False)
    return pandas.concat([rleg, final, dfp_legend]).reset_index(drop=True)


def _benchmark_xlim(final):
    """
    Computes the x-limits of the graph: the range always contains
    ``[0.5, 2]`` and is widened to cover every non-NaN value of *final*.

    @param      final       table whose first column is the label
    @return                 list ``[low, high]``
    """
    vals = final.iloc[:, 1:].to_numpy(dtype=float).ravel()
    # Builtin min/max give order-dependent results on NaN (legend rows and
    # outer-merge holes are NaN), so the NaNs are dropped first.
    vals = vals[~numpy.isnan(vals)]
    if vals.size == 0:
        return [0.5, 2]
    return [min(0.5, vals.min()), max(2, vals.max())]


def _insert_model_separators(final, fill):
    """
    Inserts a ``------`` row between two consecutive benchmark rows
    (labels containing ``[``) whose models belong to different families
    according to @see fn _model_name.

    @param      final       table with a ``label`` first column
    @param      fill        value put in every numerical column of the
                            inserted rows
    @return                 new DataFrame with a fresh integer index
    """
    labels = list(final['label'])
    blank = final.iloc[:1, :].copy()
    blank.iloc[0, 0] = '------'
    blank.iloc[0, 1:] = fill

    pieces = []
    start = 0
    for j in range(1, len(labels)):
        label, prev = labels[j], labels[j - 1]
        # Only separate two benchmark rows; legend rows have no '['.
        if '[' not in label or '[' not in prev:
            continue
        if _model_name(label.split()[0]) == _model_name(prev.split()[0]):
            continue
        pieces.append(final.iloc[start:j])
        pieces.append(blank)
        start = j
    pieces.append(final.iloc[start:])
    return pandas.concat(pieces).reset_index(drop=True)


def plot_validate_benchmark(df):
    """
    Plots a graph which summarizes the performances of a benchmark
    validating a runtime for :epkg:`ONNX`.

    @param      df      output of function @see fn summary_report
                        (left unmodified)
    @return             fig, ax

    .. plot::

        from logging import getLogger
        from pandas import DataFrame
        import matplotlib.pyplot as plt
        from mlprodict.onnxrt.validate import enumerate_validated_operator_opsets, summary_report
        from mlprodict.tools.plotting import plot_validate_benchmark

        rows = list(enumerate_validated_operator_opsets(
            verbose=0, models={"LinearRegression"}, opset_min=11,
            runtime=['python', 'onnxruntime1'], debug=False,
            benchmark=True, n_features=[None, 10]))

        df = DataFrame(rows)
        piv = summary_report(df)
        fig, ax = plot_validate_benchmark(piv)
        plt.show()
    """
    import matplotlib.pyplot as plt

    final, values, runtimes, origin = _benchmark_table(df)
    final = _add_summary_rows(final, runtimes)

    # The limits are computed before separators are inserted; separators
    # are filled with the lower limit so they draw no visible bar.
    xlim = _benchmark_xlim(final)
    final = _insert_model_separators(final, xlim[0])
    x = numpy.arange(final.shape[0])

    # graph beginning (sized after separators so every row gets room)
    total = final.shape[0] * 0.45
    fig, ax = plt.subplots(1, len(values), figsize=(14, total),
                           sharex=False, sharey=True)
    subh = 1.0 / len(runtimes)
    height = total / final.shape[0] * (subh + 0.1)
    decrt = {rt: height * i for i, rt in enumerate(runtimes)}
    # Cycle through the palette so more than 4 runtimes do not raise.
    palette = ['blue', 'orange', 'cyan', 'yellow']
    colors = {rt: palette[i % len(palette)] for i, rt in enumerate(runtimes)}

    # plt.subplots returns a single Axes (no 'shape') when len(values) == 1.
    multi = hasattr(ax, 'shape')
    done = set()
    for c in final.columns[1:]:
        place, runtime = origin[c]
        index = values.index(place) if multi else 0
        if (index, runtime) in done:
            raise RuntimeError(  # pragma: no cover
                "Issue with column '{}'\nlabels={}\nruntimes={}\ncolumns="
                "{}\nvalues={}\n{}".format(
                    c, list(final.label), runtimes, final.columns, values, final))
        done.add((index, runtime))
        axi = ax[index] if multi else ax
        axi.barh(x + decrt[runtime] / 2, final.loc[:, c], label=runtime,
                 height=height, color=colors[runtime])
        axi.set_title(place)

    def _plot_axis(axi):
        # Reference lines at ratio 1 (parity) and 2, log-scaled x axis.
        axi.plot([1, 1], [0, max(x)], 'g-')
        axi.plot([2, 2], [0, max(x)], 'r--')
        axi.set_xlim(xlim)
        axi.set_xscale('log')
        axi.set_ylim([min(x) - 2, max(x) + 1])

    if multi:
        for axi in ax:
            _plot_axis(axi)
        ax[min(ax.shape[0] - 1, 2)].legend()
        first = ax[0]
    else:  # pragma: no cover
        _plot_axis(ax)
        ax.legend()
        first = ax
    # Axes share y, so labels only need to be set on the first one.
    first.set_yticks(x)
    first.set_yticklabels(final['label'])

    fig.subplots_adjust(left=0.25)
    return fig, ax