Note
Click here to download the full example code
Compares implementations of Where#
This example compares implementations of the function numpy.where from numpy and onnxruntime; tensorflow and pytorch are included as well if they are available. The benchmark also compares the operator Where to an equivalent arithmetic formula, where(c, x, y) = x * c - y * (c - 1).
Available optimisation#
import numpy
import pandas
import matplotlib.pyplot as plt
from onnxruntime import InferenceSession
from skl2onnx.common.data_types import FloatTensorType, BooleanTensorType
from skl2onnx.algebra.onnx_ops import OnnxWhere, OnnxSub, OnnxMul
from cpyquickhelper.numbers import measure_time
from tqdm import tqdm
from mlprodict.testing.experimental_c_impl.experimental_c import code_optimisation
# Print what code_optimisation() reports for the experimental C extension
# (presumably the instruction set / OpenMP build settings -- the captured
# output below reads "AVX-omp=8"; confirm against mlprodict).
print(code_optimisation())
Out:
AVX-omp=8
Where: common code#
# tensorflow is optional: when it cannot be imported, tf_where is set to None
# and the benchmark skips the tensorflow measurements.
# NOTE: convert_to_tensor stays undefined in that case; it is only used
# behind a `tf_where is not None` check.
try:
    from tensorflow import where as tf_where, convert_to_tensor
except ImportError:
    tf_where = None
# torch is optional as well: torch_where is None when it is not installed.
# NOTE: from_numpy stays undefined in that case; it is only used behind a
# `torch_where is not None` check.
try:
    from torch import where as torch_where, from_numpy
except ImportError:
    torch_where = None
def build_ort_where(op_version=12):
    """
    Builds an onnxruntime-backed callable made of a single *Where* node.

    :param op_version: opset used for the node and as the target opset
        of the converted model
    :return: function ``f(cond, x, y)`` which feeds the three arrays to the
        inference session and returns what ``sess.run`` returns
        (all outputs are requested)
    """
    input_types = [('cond', BooleanTensorType()),
                   ('x', FloatTensorType()),
                   ('y', FloatTensorType())]
    where_node = OnnxWhere('cond', 'x', 'y', op_version=op_version,
                           output_names=['z'])
    model = where_node.to_onnx(inputs=input_types, target_opset=op_version)
    sess = InferenceSession(model.SerializeToString())

    def run(cond, x, y):
        # None requests every output of the graph.
        return sess.run(None, {'cond': cond, 'x': x, 'y': y})

    return run
def build_ort_where_add(op_version=12):
    """
    Builds an onnxruntime-backed callable computing
    ``x * cond - y * (cond - 1)`` with *Mul* and *Sub* nodes.

    Unlike :func:`build_ort_where`, the input ``cond`` is declared as a
    float tensor, not a boolean one.

    :param op_version: opset used for every node and as the target opset
        of the converted model
    :return: function ``f(cond, x, y)`` which feeds the three arrays to the
        inference session and returns what ``sess.run`` returns
        (all outputs are requested)
    """
    # Nodes are created in the same order as a nested expression would
    # evaluate them: x * cond, then cond - 1, then y * (cond - 1).
    x_term = OnnxMul('x', 'cond', op_version=op_version)
    cond_minus_one = OnnxSub(
        'cond', numpy.array([1], dtype=numpy.float32),
        op_version=op_version)
    y_term = OnnxMul('y', cond_minus_one, op_version=op_version)
    formula = OnnxSub(x_term, y_term,
                      op_version=op_version, output_names=['z'])

    input_types = [('cond', FloatTensorType()),
                   ('x', FloatTensorType()),
                   ('y', FloatTensorType())]
    model = formula.to_onnx(inputs=input_types, target_opset=op_version)
    sess = InferenceSession(model.SerializeToString())

    def run(cond, x, y):
        # None requests every output of the graph.
        return sess.run(None, {'cond': cond, 'x': x, 'y': y})

    return run
def numpy_where_add(cond, x, y):
    """
    Computes ``where(cond, x, y)`` with the arithmetic formula
    ``x * cond - y * (cond - 1)``.

    The inputs are left untouched: the previous version multiplied *y*
    in place (which altered the caller's data between two calls) and
    subtracted ``cond - 1`` instead of ``y * (cond - 1)``.

    :param cond: array of booleans or of 0/1 values
    :param x: values selected where *cond* is true (or 1)
    :param y: values selected where *cond* is false (or 0)
    :return: new array holding ``x * cond - y * (cond - 1)``,
        with the dtype of ``x * cond``
    """
    cx = x * cond
    # y * (cond - 1) is stored in a new array, never in y itself.
    cy = y * (cond - 1)
    # cx is a temporary owned by this function, reusing it avoids
    # one more allocation.
    numpy.subtract(cx, cy, out=cx)
    return cx
def loop_where(fct, conds, xs, ys):
    """
    Calls *fct* once per triple taken in lockstep from *conds*, *xs*, *ys*.

    The results of *fct* are discarded; iteration stops with the shortest
    of the three sequences.

    :param fct: callable taking ``(cond, x, y)``
    :param conds: sequence of conditions
    :param xs: sequence of first operands
    :param ys: sequence of second operands
    """
    for triple in zip(conds, xs, ys):
        fct(*triple)
def benchmark_equation():
    """
    Measures every available implementation of *where* on random square
    matrices of increasing size and plots the results.

    The implementations are ``numpy.where``, ``numpy_where_add``, the two
    onnxruntime callables, and ``tf_where`` / ``torch_where`` when these
    names are not None.

    :return: tuple ``(df, rs, ax)`` where *df* holds the raw observations
        returned by ``measure_time`` (one row per size and implementation),
        *rs* is the ratio ``numpy.where average / implementation average``
        (one row per size, one column per implementation) and *ax* the
        two matplotlib axes
    """
    statement = "loop_where(where, conds, xs, ys)"
    ort_where = build_ort_where()
    ort_where_add = build_ort_where_add()
    res = []

    def record(ctx, name, dim, repeat, number):
        # Runs the statement with the current context and tags the
        # observation with the matrix size and the implementation name.
        obs = measure_time(
            statement, div_by_number=True, context=ctx,
            repeat=repeat, number=number)
        obs['dim'] = dim
        obs['fct'] = name
        res.append(obs)

    for dim in tqdm([8, 16, 32, 64, 100, 128, 200,
                     256, 500, 512, 1024, 2048]):
        repeat = 5
        number = 10

        conds = [(numpy.random.rand(dim, dim) < 0.5).astype(numpy.bool_)
                 for _ in range(repeat)]
        xs = [numpy.random.rand(dim, dim).astype(numpy.float32)
              for _ in range(repeat)]
        ys = [numpy.random.rand(dim, dim).astype(numpy.float32)
              for _ in range(repeat)]

        # numpy
        ctx = dict(conds=conds, xs=xs, ys=ys, where=numpy.where,
                   loop_where=loop_where)
        record(ctx, 'numpy.where', dim, repeat, number)

        # numpy add
        ctx['where'] = numpy_where_add
        record(ctx, 'numpy_where_add', dim, repeat, number)

        # onnxruntime
        ctx['where'] = ort_where
        record(ctx, 'ort_where', dim, repeat, number)

        # onnxruntime - 2, the graph declares cond as a float tensor
        ctx['where'] = ort_where_add
        ctx['conds'] = [c.astype(numpy.float32) for c in conds]
        record(ctx, 'ort_where_add', dim, repeat, number)

        if tf_where is not None:
            # tensorflow, inputs are converted before the measure
            ctx['where'] = tf_where
            ctx['conds'] = [convert_to_tensor(c) for c in conds]
            ctx['xs'] = [convert_to_tensor(x) for x in xs]
            ctx['ys'] = [convert_to_tensor(y) for y in ys]
            record(ctx, 'tf_where', dim, repeat, number)

        if torch_where is not None:
            # torch, inputs are converted before the measure
            ctx['where'] = torch_where
            ctx['conds'] = [from_numpy(c) for c in conds]
            ctx['xs'] = [from_numpy(x) for x in xs]
            ctx['ys'] = [from_numpy(y) for y in ys]
            record(ctx, 'torch_where', dim, repeat, number)

    # Dataframes
    df = pandas.DataFrame(res)
    # Keyword arguments are required: pandas 2.0 made the parameters of
    # DataFrame.pivot keyword-only.
    piv = df.pivot(index='dim', columns='fct', values='average')

    # Speed-up against numpy.where, the baseline column is set to 1 last
    # because every other ratio needs its original values.
    rs = piv.copy()
    for col in rs.columns:
        if col != 'numpy.where':
            rs[col] = rs['numpy.where'] / rs[col]
    rs['numpy.where'] = 1.

    # Graphs.
    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    piv.plot(logx=True, logy=True, ax=ax[0],
             title="Where benchmark -- (N, N)\nlower better")
    ax[0].legend(prop={"size": 9})
    rs.plot(logx=True, logy=True, ax=ax[1],
            title="Where Speedup, baseline=numpy -- (N, N)\nhigher better")
    # Dashed guides at half and twice the speed of the baseline.
    ax[1].plot([min(rs.index), max(rs.index)], [0.5, 0.5], 'g--')
    ax[1].plot([min(rs.index), max(rs.index)], [2., 2.], 'g--')
    ax[1].legend(prop={"size": 9})
    return df, rs, ax
############
# Benchmark
# +++++++++
# benchmark_equation returns the raw observations, a second table it
# computes from them (bound to piv here) and the matplotlib axes.
df, piv, ax = benchmark_equation()
# One row per implementation, one column per matrix size. Keyword
# arguments are required: pandas 2.0 made the parameters of
# DataFrame.pivot keyword-only.
df.pivot(index="fct", columns="dim", values="average")
# Frames collected here are concatenated and exported at the end.
dfs = [df]
Out:
0%| | 0/12 [00:00<?, ?it/s]
8%|8 | 1/12 [00:00<00:03, 2.87it/s]
25%|##5 | 3/12 [00:00<00:01, 6.28it/s]
33%|###3 | 4/12 [00:00<00:01, 6.46it/s]
42%|####1 | 5/12 [00:00<00:01, 5.47it/s]
50%|##### | 6/12 [00:01<00:01, 4.25it/s]
58%|#####8 | 7/12 [00:03<00:04, 1.12it/s]
67%|######6 | 8/12 [00:06<00:06, 1.53s/it]
75%|#######5 | 9/12 [00:12<00:08, 2.81s/it]
83%|########3 | 10/12 [00:17<00:07, 3.66s/it]
92%|#########1| 11/12 [00:34<00:07, 7.72s/it]
100%|##########| 12/12 [01:35<00:00, 23.68s/it]
100%|##########| 12/12 [01:35<00:00, 7.95s/it]
Conclusion#
The implementation of Where should be faster than the formula where(c, x, y) = x * c - y * (c - 1).
# Exports: every collected frame goes into one table written as CSV and
# Excel, then the current figure is saved and displayed.
name = "where"
merged = pandas.concat(dfs)

merged.to_csv("plot_%s.csv" % name, index=False)
merged.to_excel("plot_%s.xlsx" % name, index=False)

plt.savefig("plot_%s.png" % name)
plt.show()
Total running time of the script: ( 1 minutes 38.319 seconds)