Coverage for mlprodict/plotting/plotting_validate_graph.py: 97%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

116 statements  

1""" 

2@file 

3@brief Functions to help visualize performances.

4""" 

5import numpy 

6import pandas 

7 

8 

9def _model_name(name): 

10 """ 

11 Extracts the main component of a model, removes 

12 suffixes such ``Classifier``, ``Regressor``, ``CV``. 

13 

14 @param name string 

15 @return shorter string 

16 """ 

17 if name.startswith("Select"): 

18 return "Select" 

19 if name.startswith("Nu"): 

20 return "Nu" 

21 modif = 1 

22 while modif > 0: 

23 modif = 0 

24 for suf in ['Classifier', 'Regressor', 'CV', 'IC', 

25 'Transformer']: 

26 if name.endswith(suf): 

27 name = name[:-len(suf)] 

28 modif += 1 

29 return name 

30 

31 

def plot_validate_benchmark(df):
    """
    Plots a graph which summarizes the performances of a benchmark
    validating a runtime for :epkg:`ONNX`.

    There is one subplot per measure, i.e. per column of *df* containing
    ``N=`` but neither ``-min`` nor ``-max``. Every benchmark row becomes
    one horizontal bar per runtime. One median row per runtime is placed
    before the benchmark rows and one average row per runtime after them.
    A separator row is inserted between two consecutive benchmark rows
    whose model names differ. Reference lines are drawn at x=1 (green)
    and x=2 (red, dashed), and the x axis uses a log scale.
    *df* itself is left unchanged.

    @param      df      output of function @see fn summary_report
    @return             fig, ax
    @raises RuntimeError if column ``'RT/SKL-N=1'`` is missing or only
                        contains missing values

    .. plot::

        from logging import getLogger
        from pandas import DataFrame
        import matplotlib.pyplot as plt
        from mlprodict.onnxrt.validate import enumerate_validated_operator_opsets, summary_report
        from mlprodict.tools.plotting import plot_validate_benchmark

        rows = list(enumerate_validated_operator_opsets(
            verbose=0, models={"LinearRegression"}, opset_min=11,
            runtime=['python', 'onnxruntime1'], debug=False,
            benchmark=True, n_features=[None, 10]))

        df = DataFrame(rows)
        piv = summary_report(df)
        fig, ax = plot_validate_benchmark(piv)
        plt.show()
    """
    import matplotlib.pyplot as plt

    final, runtimes, values = _benchmark_table(df)
    final = _add_summary_rows(final, runtimes, len(values))

    # The figure size and bar height depend on the number of rows
    # before separators are inserted.
    total = final.shape[0] * 0.45
    fig, ax = plt.subplots(1, len(values), figsize=(14, total),
                           sharex=False, sharey=True)
    subh = 1.0 / len(runtimes)
    height = total / final.shape[0] * (subh + 0.1)
    decrt = {rt: height * i for i, rt in enumerate(runtimes)}
    # Cycle through the palette so that more than 4 runtimes do not fail.
    palette = ['blue', 'orange', 'cyan', 'yellow']
    colors = {rt: palette[i % len(palette)] for i, rt in enumerate(runtimes)}

    # The x range contains at least [0.5, 2]. Missing values (outer merge,
    # empty summary cells) and infinite values are ignored. Builtin min/max
    # would return an order-dependent result in the presence of NaN.
    vals = numpy.asarray(final.iloc[:, 1:].values, dtype=numpy.float64).ravel()
    vals = vals[numpy.isfinite(vals)]
    if vals.size > 0:
        xlim = [min(0.5, vals.min()), max(2, vals.max())]
    else:
        xlim = [0.5, 2]

    final = _insert_model_separators(final, xlim[0])
    x = numpy.arange(final.shape[0])

    # plt.subplots returns a single Axes when there is only one measure.
    axes = list(ax) if hasattr(ax, 'shape') else [ax]
    # Columns are named '<measure>__<runtime>' (see _benchmark_table).
    for rt in runtimes:
        for place, axi in zip(values, axes):
            axi.barh(x + decrt[rt] / 2, final[place + "__" + rt],
                     label=rt, height=height, color=colors[rt])
            axi.set_title(place)

    for axi in axes:
        axi.plot([1, 1], [0, x.max()], 'g-')
        axi.plot([2, 2], [0, x.max()], 'r--')
        axi.set_xlim(xlim)
        axi.set_xscale('log')
        axi.set_ylim([x.min() - 2, x.max() + 1])
    # The y axis is shared, so only the first subplot shows the labels.
    axes[0].set_yticks(x)
    axes[0].set_yticklabels(final['label'])
    axes[min(len(axes) - 1, 2)].legend()

    fig.subplots_adjust(left=0.25)
    return fig, ax


def _benchmark_table(df):
    """
    Builds the table to plot from the output of @see fn summary_report:
    one row per benchmark label, then one column per (measure, runtime)
    pair named ``'<measure>__<runtime>'``.

    @param      df      output of function @see fn summary_report,
                        not modified
    @return             final, runtimes (sorted), values (measure names)
    @raises RuntimeError if column ``'RT/SKL-N=1'`` is missing or only
                        contains missing values
    """
    # Work on a copy so the caller's dataframe does not gain extra columns.
    df = df.copy()
    if 'n_features' not in df.columns:
        df["n_features"] = numpy.nan  # pragma: no cover
    if 'runtime' not in df.columns:
        df['runtime'] = '?'  # pragma: no cover

    fmt = "{} [{}-{}|{}] D{}"
    df["label"] = df.apply(
        lambda row: fmt.format(
            row["name"], row["problem"], row["scenario"],
            row['optim'], row["n_features"]).replace("-default|", "-**]"),
        axis=1)
    df = df.sort_values(["name", "problem", "scenario", "optim",
                         "n_features", "runtime"],
                        ascending=False).reset_index(drop=True)
    indices = ['label', 'runtime']
    # Only the main measures are plotted, not their -min/-max variants.
    values = [c for c in df.columns
              if 'N=' in c and '-min' not in c and '-max' not in c]
    try:
        df = df[indices + values]
    except KeyError as e:  # pragma: no cover
        raise RuntimeError(
            "Unable to find the following columns {}\nin {}".format(
                indices + values, df.columns)) from e

    if 'RT/SKL-N=1' not in df.columns:
        raise RuntimeError(  # pragma: no cover
            "Column 'RT/SKL-N=1' is missing, benchmark was probably not run.")
    dfp = df[~df["RT/SKL-N=1"].isnull()]
    runtimes = sorted(set(dfp['runtime']))
    if not runtimes:
        raise RuntimeError(
            "Column 'RT/SKL-N=1' has no value, benchmark was probably "
            "not run.")

    final = None
    for rt in runtimes:
        sub = dfp[dfp['runtime'] == rt].drop('runtime', axis=1).copy()
        # Every column except 'label' (the merge key) gets the runtime suffix.
        sub.columns = [sub.columns[0]] + [
            c + "__" + rt for c in sub.columns[1:]]
        if final is None:
            final = sub
        else:
            final = final.merge(sub, on='label', how='outer')
    return final, runtimes, values


def _add_summary_rows(final, runtimes, ncol):
    """
    Sorts the benchmark rows by decreasing label and surrounds them with
    summary rows: one median row per runtime followed by a ``'------'``
    row before them, and a ``'------'`` row followed by one average row
    per runtime after them.

    @param      final       table built by @see fn _benchmark_table
    @param      runtimes    sorted runtime names, in column order
    @param      ncol        number of measure columns per runtime
    @return                 new dataframe with a fresh range index
    """
    nleg = len(runtimes) + 1
    dfp_legend = final.iloc[:nleg, :].copy()
    # If there are too few rows, duplicate the first one. Every cell is
    # overwritten below anyway.
    while dfp_legend.shape[0] < nleg:
        dfp_legend = pandas.concat([dfp_legend, dfp_legend.iloc[:1]])
    rleg = dfp_legend.copy()
    dfp_legend.iloc[:, 1:] = numpy.nan
    rleg.iloc[:, 1:] = numpy.nan

    for r, runt in enumerate(runtimes):
        # Columns of runtime r, right after the 'label' column.
        sli = slice(1 + ncol * r, 1 + ncol * (r + 1))
        part = final.iloc[:, sli]
        dfp_legend.iloc[r + 1, sli] = part.mean().values
        rleg.iloc[r, sli] = part.median().values
        dfp_legend.iloc[r + 1, 0] = "avg_" + runt
        rleg.iloc[r, 0] = "med_" + runt
    dfp_legend.iloc[0, 0] = "------"
    rleg.iloc[-1, 0] = "------"

    final = final.sort_values('label', ascending=False)
    return pandas.concat([rleg, final, dfp_legend]).reset_index(drop=True)


def _insert_model_separators(final, fill):
    """
    Inserts a ``'------'`` row between two consecutive benchmark rows
    whose model names (see @see fn _model_name) differ.
    Only rows whose label contains ``'['`` are benchmark rows. Summary rows
    and existing separators never trigger a new separator.

    @param      final   table built by @see fn _add_summary_rows
    @param      fill    value stored in every numerical cell of a separator
    @return             new dataframe with a fresh range index
    """
    labels = list(final.iloc[:, 0])

    def _base(label):
        # The first whitespace-separated token of the label is the model name.
        return _model_name(label.split()[0])

    cuts = [i for i in range(1, len(labels))
            if '[' in labels[i] and '[' in labels[i - 1] and
            _base(labels[i]) != _base(labels[i - 1])]
    if not cuts:
        return final.reset_index(drop=True)

    blank = final.iloc[:1, :].copy()
    blank.iloc[0, 0] = '------'
    blank.iloc[0, 1:] = fill
    pieces = []
    start = 0
    for cut in cuts:
        pieces.extend([final.iloc[start:cut], blank])
        start = cut
    pieces.append(final.iloc[start:])
    return pandas.concat(pieces).reset_index(drop=True)