Coverage for mlprodict/plotting/plotting_benchmark.py: 100%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Useful plots.
4"""
5import numpy
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel=None, **kwargs):
    """
    Creates a heatmap from a 2D array and two lists of labels.
    See @see fn plot_benchmark_metrics for an example.

    @param      data            a 2D array of shape (N, M), either a numpy array
                                or a nested list (its shape is read with
                                :func:`numpy.shape`)
    @param      row_labels      a list or array of length N with the labels for the rows
    @param      col_labels      a list or array of length M with the labels for the columns
    @param      ax              a `matplotlib.axes.Axes` instance to which the heatmap
                                is plotted, if None, uses the current axes or
                                creates a new one (``plt.gca()``). Optional.
    @param      cbar_kw         a dictionary with arguments to `matplotlib.Figure.colorbar
                                <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.colorbar.html>`_.
                                Optional.
    @param      cbarlabel       the label for the colorbar. Optional.
    @param      kwargs          all other arguments are forwarded to `imshow
                                <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.imshow.html>`_
    @return                     ax, image, color bar
    """
    import matplotlib.pyplot as plt  # delayed

    if ax is None:
        ax = plt.gca()  # pragma: no cover

    # Plot the heatmap.
    im = ax.imshow(data, **kwargs)

    # numpy.shape works for numpy arrays and nested lists alike,
    # both of which imshow accepts.
    shape = numpy.shape(data)
    n_rows, n_cols = shape[0], shape[1]

    # Create the colorbar.
    if cbar_kw is None:
        cbar_kw = {}
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    if cbarlabel is not None:
        cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # One major tick per cell, labelled with the row / column names.
    ax.set_xticks(numpy.arange(n_cols))
    ax.set_yticks(numpy.arange(n_rows))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Column labels are displayed above the image.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the column labels so that long ones do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Hide the frame and draw a white grid on minor ticks placed at the
    # cell boundaries (half-integer positions) to visually separate cells.
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks(numpy.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(numpy.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)
    return ax, im, cbar


def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "black"),
                     threshold=None, **textkw):
    """
    Annotates a heatmap.
    See @see fn plot_benchmark_metrics for an example.

    @param      im              the *AxesImage* to be labeled.
    @param      data            data used to annotate, a numpy array or a list of lists.
                                For any other value (None by default) the image's
                                data (``im.get_array()``) is used. Optional.
    @param      valfmt          the format of the annotations inside the heatmap. This should
                                either use the string format method, e.g. `"$ {x:.2f}"`, or be
                                a `matplotlib.ticker.Formatter
                                <https://matplotlib.org/api/ticker_api.html>`_
                                (any callable ``valfmt(value, position)``). Optional.
    @param      textcolors      a list or array of two color specifications. The first is
                                used for values whose normalized value is not above the
                                threshold, the second for those above. Optional.
    @param      threshold       value in data units according to which the colors from
                                textcolors are applied. If None (the default), half of the
                                normalized maximum of the data is used, which is the middle
                                of the colormap when the norm maps the data maximum to 1.
                                Optional.
    @param      textkw          all other arguments are forwarded to each call to `text`
                                used to create the text labels.
    @return                     list of the created text objects, row by row
    """
    if isinstance(data, list):
        # Lists are accepted but need 2D indexing and ``.max()``.
        data = numpy.asarray(data)
    elif not isinstance(data, numpy.ndarray):
        # Masked arrays are ndarray subclasses and are kept untouched.
        data = im.get_array()

    # The threshold is compared to normalized values (``im.norm``).
    if threshold is not None:
        threshold = im.norm(threshold)  # pragma: no cover
    else:
        threshold = im.norm(data.max()) / 2.

    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied.
    if isinstance(valfmt, str):
        # Import the submodule explicitly: ``import matplotlib`` alone
        # does not guarantee ``matplotlib.ticker`` is loaded.
        from matplotlib.ticker import StrMethodFormatter  # delayed
        valfmt = StrMethodFormatter(valfmt)

    texts = []
    n_rows, n_cols = data.shape[0], data.shape[1]
    for i in range(n_rows):
        for j in range(n_cols):
            value = data[i, j]
            kw["color"] = textcolors[int(im.norm(value) > threshold)]
            # imshow places cell (i, j) at x=j, y=i.
            texts.append(im.axes.text(j, i, valfmt(value, None), **kw))
    return texts


def plot_benchmark_metrics(metric, xlabel=None, ylabel=None,
                           middle=1., transpose=False, ax=None,
                           cbar_kw=None, cbarlabel=None,
                           valfmt="{x:.2f}x"):
    """
    Plots a heatmap which represents a benchmark.
    See example below.

    @param      metric      dictionary ``{ (x,y): value }``, values are expected to be
                            strictly positive as a logarithmic color scale is used,
                            pairs missing from the dictionary are filled with 0
    @param      xlabel      x label
    @param      ylabel      y label
    @param      middle      force the white color to be this value: the color range
                            is extended so that it is geometrically centered on
                            *middle*, None to use the range of the values
    @param      transpose   switches *x* and *y*
    @param      ax          axis to borrow
    @param      cbar_kw     a dictionary with arguments to `matplotlib.Figure.colorbar
                            <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.colorbar.html>`_.
                            Optional.
    @param      cbarlabel   the label for the colorbar. Optional.
    @param      valfmt      format for the annotations
    @return                 ax, colorbar
    @raises     ValueError  if *metric* is empty

    .. exref::
        :title: Plot benchmark improvements
        :lid: plot-2d-benchmark-metric

        .. plot::

            import matplotlib.pyplot as plt
            from mlprodict.plotting.plotting_benchmark import plot_benchmark_metrics

            data = {(1, 1): 0.1, (10, 1): 1, (1, 10): 2,
                    (10, 10): 100, (100, 1): 100, (100, 10): 1000}

            fig, ax = plt.subplots(1, 2, figsize=(10, 4))
            plot_benchmark_metrics(data, ax=ax[0], cbar_kw={'shrink': 0.6})
            plot_benchmark_metrics(data, ax=ax[1], transpose=True,
                                   xlabel='X', ylabel='Y',
                                   cbarlabel="ratio")
            plt.show()
    """
    if transpose:
        # Swap coordinates and axis labels, then draw the regular version.
        # Every other option (including valfmt) must be forwarded.
        metric = {(k[1], k[0]): v for k, v in metric.items()}
        return plot_benchmark_metrics(
            metric, xlabel=ylabel, ylabel=xlabel, middle=middle,
            transpose=False, ax=ax, cbar_kw=cbar_kw,
            cbarlabel=cbarlabel, valfmt=valfmt)

    if not metric:
        raise ValueError("metric is empty, there is nothing to plot.")

    from matplotlib.colors import LogNorm  # delayed

    # Sorted distinct coordinates on each axis and their position.
    x = numpy.array(sorted({k[0] for k in metric}))
    y = numpy.array(sorted({k[1] for k in metric}))
    rx = {v: i for i, v in enumerate(x)}
    ry = {v: i for i, v in enumerate(y)}

    # Rows follow y, columns follow x.
    zm = numpy.zeros((len(y), len(x)), dtype=numpy.float64)
    for k, v in metric.items():
        zm[ry[k[1]], rx[k[0]]] = v

    xs = [str(v) for v in x]
    ys = [str(v) for v in y]

    values = list(metric.values())
    vmin, vmax = min(values), max(values)
    if middle is not None:
        # LogNorm maps log(value) linearly onto the colormap whose white
        # is in the middle ("bwr"): white falls on *middle* only when
        # vmin * vmax == middle ** 2. One bound is extended to get there.
        square = middle ** 2
        low = square / vmax
        high = square / vmin
        vmin, vmax = min(vmin, low), max(vmax, high)

    ax, im, cbar = heatmap(zm, ys, xs, ax=ax, cmap="bwr",
                           norm=LogNorm(vmin=vmin, vmax=vmax),
                           cbarlabel=cbarlabel, cbar_kw=cbar_kw)
    annotate_heatmap(im, valfmt=valfmt)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    return ax, cbar