Compares numpy to onnxruntime on simple functions

onnxruntime can be used as a replacement for numpy. It can also be used to implement a training algorithm: onnxruntime-training differentiates an onnx graph and runs it to compute the gradient. Simple functions are implemented in ONNX and run with onnxruntime to update the weights. function_onnx_graph returns many of the functions needed to implement a training algorithm. The following benchmark compares several implementations (the three onnxruntime execution modes are sketched just after the list):

  • numpy: a baseline implementation with numpy, not optimized

  • sess: inference through an ONNX graph executed with method onnxruntime.InferenceSession.run

  • bind_run: inference through an ONNX graph executed with method onnxruntime.InferenceSession.run_with_iobinding, including the time spent binding inputs and outputs

  • run: the same inference through onnxruntime.InferenceSession.run_with_iobinding but without counting the binding, assuming input buffers are reused and do not need to be bound again
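
The following sketch is not part of the benchmark; it only illustrates, with a hypothetical model file model.onnx and hypothetical input/output names X1, X2 and Y, how the three onnxruntime execution modes differ.

import numpy
from onnxruntime import InferenceSession, OrtValue

sess = InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
x1 = numpy.random.randn(100, 10).astype(numpy.float32)
x2 = numpy.random.randn(100, 10).astype(numpy.float32)

# 'sess': plain run, numpy arrays are given to and returned by onnxruntime
y = sess.run(None, {"X1": x1, "X2": x2})[0]

# 'bind_run': inputs and outputs are first bound to OrtValue buffers,
# then the session is executed through the binding
bind = sess.io_binding()
bind.bind_ortvalue_input("X1", OrtValue.ortvalue_from_numpy(x1))
bind.bind_ortvalue_input("X2", OrtValue.ortvalue_from_numpy(x2))
bind.bind_output("Y")

# 'run': only this call is measured, assuming the binding above is reused
sess.run_with_iobinding(bind)
y2 = bind.get_outputs()[0].numpy()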

axpy

This function implements Y = f(X1, X2, \alpha) = \alpha X1 + X2.

import numpy
from scipy.special import expit
import pandas
from tqdm import tqdm
from cpyquickhelper.numbers.speed_measure import measure_time
import matplotlib.pyplot as plt
from onnxruntime import InferenceSession
from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
    SessionIOBinding, OrtDevice as C_OrtDevice,
    OrtValue as C_OrtValue)
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from onnxcustom.utils.onnx_function import function_onnx_graph

fct_onx = function_onnx_graph("axpy")
print(onnx_simple_text_plot(fct_onx))

Out:

opset: domain='' version=14
input: name='X1' type=dtype('float32') shape=()
input: name='X2' type=dtype('float32') shape=()
input: name='alpha' type=dtype('float32') shape=(1,)
Mul(X1, alpha) -> Mu_C0
  Add(Mu_C0, X2) -> Y
output: name='Y' type=dtype('float32') shape=()

The numpy implementation is the following.

fct_numpy = lambda X1, X2, alpha: X1 * alpha + X2
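
As a quick check (not part of the original script), the ONNX graph and the numpy lambda can be compared on a small random input; this reuses fct_onx and fct_numpy defined above.

sess_check = InferenceSession(fct_onx.SerializeToString())
x1 = numpy.random.randn(4, 3).astype(numpy.float32)
x2 = numpy.random.randn(4, 3).astype(numpy.float32)
alpha = numpy.array([0.5], dtype=numpy.float32)

expected = fct_numpy(x1, x2, alpha)
got = sess_check.run(None, {'X1': x1, 'X2': x2, 'alpha': alpha})[0]
numpy.testing.assert_allclose(expected, got, rtol=1e-5)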

The benchmark

def reshape(a, dim):
    # Keeps only the first `dim` rows of a 2D array, returns other arrays unchanged.
    if len(a.shape) == 2:
        return a[:dim].copy()
    return a


def bind_and_run(sess, bind, names, args, out_names, device):
    # Binds every input and output, then runs the session through the binding.
    for n, a in zip(names, args):
        bind.bind_ortvalue_input(n, a)
    for o in out_names:
        bind.bind_output(o, device)
    sess.run_with_iobinding(bind, None)
    return bind.get_outputs()


def nobind_just_run(sess, bind):
    # Runs the session with an already populated binding (no binding cost measured).
    sess.run_with_iobinding(bind, None)


def benchmark(name, onx, fct_numpy, *args,
              dims=(1, 10, 100, 200, 500, 1000, 2000, 10000)):
    # Measures fct_numpy and the three onnxruntime execution modes on the
    # same inputs, truncated to their first `dim` rows for every dim in dims.
    sess = InferenceSession(onx.SerializeToString())
    device = C_OrtDevice(
        C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0)
    names = [i.name for i in sess.get_inputs()]
    out_names = [o.name for o in sess.get_outputs()]
    if len(names) != len(args):
        raise RuntimeError(
            "Size mismatch %d != %d." % (len(names), len(args)))

    rows = []
    for dim in tqdm(dims):
        new_args = [reshape(a, dim) for a in args]
        ortvalues = [
            C_OrtValue.ortvalue_from_numpy(a, device)
            for a in new_args]

        ms = measure_time(lambda: fct_numpy(*new_args),
                          repeat=50, number=100)
        ms.update(dict(name=name, impl='numpy', dim=dim))
        rows.append(ms)

        inps = {n: a for n, a in zip(names, new_args)}
        ms = measure_time(lambda: sess.run(None, inps))
        ms.update(dict(name=name, impl='sess', dim=dim))
        rows.append(ms)

        bind = SessionIOBinding(sess._sess)
        ms = measure_time(
            lambda: bind_and_run(
                sess._sess, bind, names, ortvalues, out_names, device))
        ms.update(dict(name=name, impl='bind_run', dim=dim))
        rows.append(ms)

        ms = measure_time(
            lambda: nobind_just_run(sess._sess, bind))
        ms.update(dict(name=name, impl='run', dim=dim))
        rows.append(ms)

    return rows

Back to function axpy.

rows = benchmark(
    'axpy', fct_onx, fct_numpy,
    numpy.random.randn(1000, 10).astype(numpy.float32),
    numpy.random.randn(1000, 10).astype(numpy.float32),
    numpy.array([0.5], dtype=numpy.float32))

all_rows = []
all_rows.extend(rows)
piv = pandas.DataFrame(rows).pivot('dim', 'impl', 'average')
piv

Out:

impl   bind_run     numpy       run      sess
dim
1      0.004020  0.002472  0.001767  0.003834
10     0.004144  0.002464  0.001897  0.003902
100    0.004209  0.003188  0.001983  0.004204
200    0.004304  0.003303  0.002080  0.004269
500    0.004564  0.003738  0.002329  0.004624
1000   0.005001  0.004787  0.002745  0.005261
2000   0.004988  0.004789  0.002764  0.005207
10000  0.005004  0.004765  0.002741  0.005211


Graph.

name = rows[0]['name']
ax = piv.plot(logx=True, logy=True)
ax.set_title(name + "\nlower is better")
(figure: axpy, lower is better)

Out:

Text(0.5, 1.0, 'axpy\nlower is better')

axpyw

It implements Y, Z = f(X1, X2, G, \alpha, \beta) where Z = \beta G + \alpha X1 and Y = Z + X2.
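
Read with training in mind, this is one step of gradient descent with momentum (a plausible interpretation, not stated by the source): X1 would hold the gradient, X2 the weights, G the previous velocity, \alpha the learning rate (negated) and \beta the momentum term. A minimal numpy sketch with hypothetical values:

gradient = numpy.random.randn(5).astype(numpy.float32)   # X1
weights = numpy.random.randn(5).astype(numpy.float32)    # X2
velocity = numpy.zeros(5, dtype=numpy.float32)            # G
alpha = numpy.float32(-0.1)   # learning rate, negated so the step descends
beta = numpy.float32(0.9)     # momentum

velocity = beta * velocity + alpha * gradient   # Z, the new velocity
weights = velocity + weights                    # Y, the updated weights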

fct_onx = function_onnx_graph("axpyw")
print(onnx_simple_text_plot(fct_onx))

Out:

opset: domain='' version=14
input: name='X1' type=dtype('float32') shape=()
input: name='X2' type=dtype('float32') shape=()
input: name='G' type=dtype('float32') shape=()
input: name='alpha' type=dtype('float32') shape=(1,)
input: name='beta' type=dtype('float32') shape=(1,)
Mul(X1, alpha) -> Mu_C0
Mul(G, beta) -> Mu_C02
  Add(Mu_C0, Mu_C02) -> Z
    Add(Z, X2) -> Y
output: name='Y' type=dtype('float32') shape=()
output: name='Z' type=dtype('float32') shape=()

benchmark

fct_numpy = lambda x1, x2, g, alpha, beta: (
    x1 * alpha + x2 + beta * g, x1 * alpha + beta * g)

rows = benchmark(
    'axpyw', fct_onx, fct_numpy,
    numpy.random.randn(1000, 10).astype(numpy.float32),
    numpy.random.randn(1000, 10).astype(numpy.float32),
    numpy.random.randn(1000, 10).astype(numpy.float32),
    numpy.array([0.5], dtype=numpy.float32),
    numpy.array([0.5], dtype=numpy.float32))

all_rows.extend(rows)
piv = pandas.DataFrame(rows).pivot('dim', 'impl', 'average')
piv

Out:

impl   bind_run     numpy       run      sess
dim
1      0.005410  0.008006  0.002002  0.004804
10     0.005511  0.008061  0.002049  0.004867
100    0.005652  0.010279  0.002217  0.005257
200    0.005859  0.011089  0.002417  0.005517
500    0.006377  0.012870  0.002917  0.006208
1000   0.007307  0.016126  0.003762  0.007902
2000   0.007499  0.016227  0.003889  0.007580
10000  0.007290  0.016235  0.003779  0.007562


Graph.

name = rows[0]['name']
ax = piv.plot(logx=True, logy=True)
ax.set_title(name + "\nlower is better")
(figure: axpyw, lower is better)

Out:

Text(0.5, 1.0, 'axpyw\nlower is better')

axpyw2

It implements Y, Z = f(X1, X2, G, \alpha, \beta) where Z = \beta G + \alpha X1 and Y = \beta Z + \alpha X1 + X2.

fct_onx = function_onnx_graph("axpyw2")
print(onnx_simple_text_plot(fct_onx))

Out:

opset: domain='' version=14
input: name='X1' type=dtype('float32') shape=()
input: name='X2' type=dtype('float32') shape=()
input: name='G' type=dtype('float32') shape=()
input: name='alpha' type=dtype('float32') shape=(1,)
input: name='beta' type=dtype('float32') shape=(1,)
Mul(X1, alpha) -> Mu_C0
Mul(G, beta) -> Mu_C03
  Add(Mu_C0, Mu_C03) -> Z
    Mul(Z, beta) -> Mu_C02
  Add(Mu_C0, Mu_C02) -> Ad_C0
    Add(Ad_C0, X2) -> Y
output: name='Y' type=dtype('float32') shape=()
output: name='Z' type=dtype('float32') shape=()

benchmark

fct_numpy = lambda x1, x2, g, alpha, beta: (
    x1 * alpha + x2 + beta * (x1 * alpha + beta * g),
    x1 * alpha + beta * g)

rows = benchmark(
    'axpyw2', fct_onx, fct_numpy,
    numpy.random.randn(1000, 10).astype(numpy.float32),
    numpy.random.randn(1000, 10).astype(numpy.float32),
    numpy.random.randn(1000, 10).astype(numpy.float32),
    numpy.array([0.5], dtype=numpy.float32),
    numpy.array([0.5], dtype=numpy.float32))

all_rows.extend(rows)
piv = pandas.DataFrame(rows).pivot('dim', 'impl', 'average')
piv

Out:

impl   bind_run     numpy       run      sess
dim
1      0.005715  0.011349  0.002276  0.005124
10     0.005753  0.011395  0.002312  0.005138
100    0.006031  0.014726  0.002566  0.005628
200    0.006341  0.015799  0.002866  0.006004
500    0.007087  0.018475  0.003629  0.007236
1000   0.008660  0.023100  0.004913  0.008972
2000   0.008596  0.023859  0.004916  0.009152
10000  0.008592  0.023808  0.004910  0.009274


copy

It implements a copy.

fct_onx = function_onnx_graph("copy")
print(onnx_simple_text_plot(fct_onx))

Out:

opset: domain='' version=15
input: name='X' type=dtype('float32') shape=()
Identity(X) -> Y
output: name='Y' type=dtype('float32') shape=()

benchmark

fct_numpy = lambda x: x.copy()

rows = benchmark(
    'copy', fct_onx, fct_numpy,
    numpy.random.randn(1000, 10).astype(numpy.float32))

all_rows.extend(rows)
piv = pandas.DataFrame(rows).pivot('dim', 'impl', 'average')
piv

Out:

impl   bind_run     numpy       run      sess
dim
1      0.002721  0.000649  0.001098  0.002707
10     0.002737  0.000627  0.001103  0.002724
100    0.002757  0.000836  0.001135  0.002853
200    0.002794  0.000900  0.001172  0.002910
500    0.002894  0.001054  0.001247  0.003066
1000   0.002964  0.001348  0.001341  0.003318
2000   0.002971  0.001361  0.001360  0.003253
10000  0.002981  0.001324  0.001353  0.003254


Graph.

name = rows[0]['name']
ax = piv.plot(logx=True, logy=True)
ax.set_title(name + "\nlower is better")
(figure: copy, lower is better)

Out:

Text(0.5, 1.0, 'copy\nlower is better')

grad_loss_absolute_error

It implements Y = f(X1, X2) = \lVert X1 - X2 \rVert.
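
The second output below, Y_grad = sign(X1 - X2), is the derivative of this loss: with u = X1 - X2, \frac{\partial}{\partial X1} \sum_i |u_i| = sign(u) (with respect to X2 the sign is reversed).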

fct_onx = function_onnx_graph("grad_loss_absolute_error")
print(onnx_simple_text_plot(fct_onx))

Out:

opset: domain='' version=14
input: name='X1' type=dtype('float32') shape=(0, 0)
input: name='X2' type=dtype('float32') shape=(0, 0)
init: name='Re_Reshapecst' type=dtype('int64') shape=(1,) -- array([-1])
Sub(X1, X2) -> Su_C0
  Sign(Su_C0) -> Y_grad
  Abs(Su_C0) -> Ab_Y0
    ReduceSum(Ab_Y0) -> Re_reduced0
      Reshape(Re_reduced0, Re_Reshapecst) -> Y
output: name='Y' type=dtype('float32') shape=()
output: name='Y_grad' type=dtype('float32') shape=()

benchmark

fct_numpy = lambda x1, x2: (
    numpy.abs(x1 - x2).sum(), numpy.sign(x1 - x2))

rows = benchmark(
    'grad_loss_absolute_error', fct_onx, fct_numpy,
    numpy.random.randn(1000, 10).astype(numpy.float32),
    numpy.random.randn(1000, 10).astype(numpy.float32))

all_rows.extend(rows)
piv = pandas.DataFrame(rows).pivot('dim', 'impl', 'average')
piv

Out:

impl   bind_run     numpy       run      sess
dim
1      0.004892  0.005378  0.002343  0.004455
10     0.004940  0.005400  0.002387  0.004474
100    0.005286  0.007336  0.002758  0.005016
200    0.005728  0.008320  0.003188  0.005466
500    0.006940  0.010785  0.004406  0.006787
1000   0.008976  0.015394  0.006398  0.008920
2000   0.009048  0.015723  0.006436  0.008967
10000  0.009012  0.015590  0.006411  0.008881


Graph.

name = rows[0]['name']
ax = piv.plot(logx=True, logy=True)
ax.set_title(name + "\nlower is better")
(figure: grad_loss_absolute_error, lower is better)

Out:

Text(0.5, 1.0, 'grad_loss_absolute_error\nlower is better')

grad_loss_square_error

It implements Y = f(X1, X2) = \lVert X1 - X2 \rVert^2.
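
The second output below, Y_grad = -2 (X1 - X2), matches the derivative of the squared error with respect to X2: \frac{\partial}{\partial X2} \lVert X1 - X2 \rVert^2 = -2 (X1 - X2).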

fct_onx = function_onnx_graph("grad_loss_square_error")
print(onnx_simple_text_plot(fct_onx))

Out:

opset: domain='' version=14
input: name='X1' type=dtype('float32') shape=(0, 0)
input: name='X2' type=dtype('float32') shape=(0, 0)
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([1.], dtype=float32)
init: name='Re_Reshapecst' type=dtype('int64') shape=(1,) -- array([-1])
init: name='Mu_Mulcst1' type=dtype('float32') shape=(1,) -- array([-2.], dtype=float32)
Sub(X1, X2) -> Su_C0
  Mul(Su_C0, Mu_Mulcst1) -> Y_grad
ReduceSumSquare(Su_C0) -> Re_reduced0
  Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
    Reshape(Mu_C0, Re_Reshapecst) -> Y
output: name='Y' type=dtype('float32') shape=()
output: name='Y_grad' type=dtype('float32') shape=()

benchmark

fct_numpy = lambda x1, x2: (
    ((x1 - x2) ** 2).sum(), (x1 - x2) * (-2))

rows = benchmark(
    'grad_loss_square_error', fct_onx, fct_numpy,
    numpy.random.randn(1000, 10).astype(numpy.float32),
    numpy.random.randn(1000, 10).astype(numpy.float32))

all_rows.extend(rows)
piv = pandas.DataFrame(rows).pivot('dim', 'impl', 'average')
piv

Out:

impl   bind_run     numpy       run      sess
dim
1      0.004992  0.006728  0.002403  0.004590
10     0.004987  0.006741  0.002423  0.004630
100    0.005123  0.008606  0.002535  0.004961
200    0.005240  0.008759  0.002665  0.005020
500    0.005583  0.010121  0.003018  0.005441
1000   0.006175  0.012511  0.003565  0.006238
2000   0.006152  0.012755  0.003566  0.006123
10000  0.006152  0.012709  0.003568  0.006112


Graph.

name = rows[0]['name']
ax = piv.plot(logx=True, logy=True)
ax.set_title(name + "\nlower is better")
(figure: grad_loss_square_error, lower is better)

Out:

Text(0.5, 1.0, 'grad_loss_square_error\nlower is better')

grad_loss_elastic_error

It implements Y = f(X1, X2) = \beta \lVert X1 - X2 \rVert + \alpha \lVert X1 - X2 \rVert^2 or, if weight_name is not None, Y = f(X1, X2) = \beta \lVert w (X1 - X2) \rVert + \alpha \lVert \sqrt{w} (X1 - X2) \rVert^2, and its gradient. l1_weight is \beta and l2_weight is \alpha.

fct_onx = function_onnx_graph("grad_loss_elastic_error")
print(onnx_simple_text_plot(fct_onx))

Out:

opset: domain='' version=15
input: name='X1' type=dtype('float32') shape=(0, 0)
input: name='X2' type=dtype('float32') shape=(0, 0)
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.01], dtype=float32)
init: name='Re_Reshapecst' type=dtype('int64') shape=(1,) -- array([-1])
init: name='Mu_Mulcst3' type=dtype('float32') shape=(1,) -- array([-0.02], dtype=float32)
Identity(Mu_Mulcst) -> Mu_Mulcst1
Identity(Mu_Mulcst) -> Mu_Mulcst2
Sub(X1, X2) -> Su_C0
  Mul(Su_C0, Mu_Mulcst3) -> Mu_C05
Sign(Su_C0) -> Si_output0
  Mul(Si_output0, Mu_Mulcst2) -> Mu_C04
    Add(Mu_C04, Mu_C05) -> Ad_C02
      Identity(Ad_C02) -> Y_grad
  Mul(Su_C0, Su_C0) -> Mu_C03
  Mul(Mu_C03, Mu_Mulcst1) -> Mu_C02
Abs(Su_C0) -> Ab_Y0
  Mul(Ab_Y0, Mu_Mulcst) -> Mu_C0
    Add(Mu_C0, Mu_C02) -> Ad_C0
      ReduceSum(Ad_C0) -> Re_reduced0
        Reshape(Re_reduced0, Re_Reshapecst) -> Y
output: name='Y' type=dtype('float32') shape=()
output: name='Y_grad' type=dtype('float32') shape=()

benchmark

fct_numpy = lambda x1, x2: (
    numpy.abs(x1 - x2).sum() * 0.1 + ((x1 - x2) ** 2).sum() * 0.9,
    numpy.sign(x1 - x2) * 0.1 - 2 * 0.9 * (x1 - x2))

rows = benchmark(
    'grad_loss_elastic_error', fct_onx, fct_numpy,
    numpy.random.randn(1000, 10).astype(numpy.float32),
    numpy.random.randn(1000, 10).astype(numpy.float32))

all_rows.extend(rows)
piv = pandas.DataFrame(rows).pivot('dim', 'impl', 'average')
piv

Out:

impl   bind_run     numpy       run      sess
dim
1      0.006727  0.018878  0.004245  0.006339
10     0.006791  0.018810  0.004312  0.006345
100    0.007568  0.023246  0.005018  0.007241
200    0.008349  0.024948  0.005816  0.008013
500    0.010447  0.029153  0.007927  0.010274
1000   0.014111  0.036601  0.011493  0.014316
2000   0.014031  0.036314  0.011409  0.014508
10000  0.014107  0.036410  0.011520  0.014311


Graph.

name = rows[0]['name']
ax = piv.plot(logx=True, logy=True)
ax.set_title(name + "\nlower is better")
(figure: grad_loss_elastic_error, lower is better)

Out:

Text(0.5, 1.0, 'grad_loss_elastic_error\nlower is better')

n_penalty_elastic_error

It implements Y = f(W) = \beta \lVert W \rVert + \alpha \lVert W \rVert^2 where l1_weight is \beta and l2_weight is \alpha. It does that for n_tensors tensors and adds all the results to an input loss.
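
Written out with loss the input loss and W_1, ..., W_n the n_tensors inputs, the function computes Y = loss + \sum_{i=1}^{n} \left(\beta \lVert W_i \rVert + \alpha \lVert W_i \rVert^2\right).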

fct_onx = function_onnx_graph("n_penalty_elastic_error")
print(onnx_simple_text_plot(fct_onx))

Out:

opset: domain='' version=15
input: name='loss' type=dtype('float32') shape=(1, 1)
input: name='W0' type=dtype('float32') shape=()
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.01], dtype=float32)
init: name='Re_Reshapecst' type=dtype('int64') shape=(1,) -- array([-1])
Abs(W0) -> Ab_Y0
  ReduceSum(Ab_Y0) -> Re_reduced0
    Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
ReduceSumSquare(W0) -> Re_reduced02
Identity(Mu_Mulcst) -> Mu_Mulcst1
  Mul(Re_reduced02, Mu_Mulcst1) -> Mu_C02
    Add(Mu_C0, Mu_C02) -> Ad_C01
      Add(loss, Ad_C01) -> Ad_C0
        Reshape(Ad_C0, Re_Reshapecst) -> Y
output: name='Y' type=dtype('float32') shape=(0,)

benchmark

fct_numpy = lambda loss, x: numpy.abs(x).sum() * 0.1 + ((x) ** 2).sum() * 0.9

rows = benchmark(
    'n_penalty_elastic_error', fct_onx, fct_numpy,
    numpy.array([[0.5]], dtype=numpy.float32),
    numpy.random.randn(1000, 10).astype(numpy.float32))

all_rows.extend(rows)
piv = pandas.DataFrame(rows).pivot('dim', 'impl', 'average')
piv

Out:

impl   bind_run     numpy       run      sess
dim
1      0.004585  0.011367  0.002621  0.004443
10     0.004578  0.011417  0.002631  0.004412
100    0.004684  0.012726  0.002711  0.004551
200    0.004766  0.013061  0.002809  0.004637
500    0.005022  0.014097  0.003079  0.004913
1000   0.005446  0.015708  0.003480  0.005334
2000   0.005418  0.015659  0.003477  0.005255
10000  0.005448  0.015799  0.003486  0.005266


Graph.

name = rows[0]['name']
ax = piv.plot(logx=True, logy=True)
ax.set_title(name + "\nlower is better")
(figure: n_penalty_elastic_error, lower is better)

Out:

Text(0.5, 1.0, 'n_penalty_elastic_error\nlower is better')

update_penalty_elastic_error

It implements Y = f(W) = W - 2 \beta W - \alpha sign(W) where l1 is \beta and l2 is \alpha.
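
Rewritten as Y = (1 - 2 \beta) W - \alpha sign(W), this is the shrinkage applied to the weights by the L2 and L1 penalties; in the default graph below the initializers 0.9998 and 1e-04 correspond to \beta = 10^{-4} and \alpha = 10^{-4}.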

fct_onx = function_onnx_graph("update_penalty_elastic_error")
print(onnx_simple_text_plot(fct_onx))

Out:

opset: domain='' version=14
input: name='X' type=dtype('float32') shape=()
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.9998], dtype=float32)
init: name='Mu_Mulcst1' type=dtype('float32') shape=(1,) -- array([1.e-04], dtype=float32)
Mul(X, Mu_Mulcst) -> Mu_C0
Sign(X) -> Si_output0
  Mul(Si_output0, Mu_Mulcst1) -> Mu_C02
  Sub(Mu_C0, Mu_C02) -> Y
output: name='Y' type=dtype('float32') shape=()

benchmark

fct_numpy = lambda x: numpy.sign(x) * 0.1 + (x * 0.9 * 2)

rows = benchmark(
    'update_penalty_elastic_error', fct_onx, fct_numpy,
    numpy.random.randn(1000, 10).astype(numpy.float32))

all_rows.extend(rows)
piv = pandas.DataFrame(rows).pivot('dim', 'impl', 'average')
piv

Out:

impl   bind_run     numpy       run      sess
dim
1      0.003592  0.006939  0.001929  0.003570
10     0.003595  0.006948  0.001966  0.003547
100    0.004011  0.009130  0.002370  0.004365
200    0.004455  0.010068  0.002836  0.004784
500    0.005754  0.012711  0.004112  0.006313
1000   0.007908  0.016946  0.006276  0.008500
2000   0.007889  0.016904  0.006242  0.008482
10000  0.007912  0.017042  0.006264  0.008486


Graph.

name = rows[0]['name']
ax = piv.plot(logx=True, logy=True)
ax.set_title(name + "\nlower is better")
(figure: update_penalty_elastic_error, lower is better)

Out:

Text(0.5, 1.0, 'update_penalty_elastic_error\nlower is better')

grad_sigmoid_neg_log_loss_error

See _onnx_grad_sigmoid_neg_log_loss_error.
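
This function computes the negative log-likelihood of a binary classifier and its gradient: with p = \sigma(X2) the sigmoid of the logits, clipped away from 0 and 1, Y = -\sum \left(X1 \log p + (1 - X1) \log(1 - p)\right) and Y_grad = p - X1, which is the usual derivative of this loss with respect to the logits X2.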

fct_onx = function_onnx_graph("grad_sigmoid_neg_log_loss_error")
print(onnx_simple_text_plot(fct_onx))

Out:

opset: domain='' version=15
input: name='X1' type=dtype('int64') shape=(0, 0)
input: name='X2' type=dtype('float32') shape=(0, 0)
init: name='Su_Subcst' type=dtype('float32') shape=(1,) -- array([1.], dtype=float32)
init: name='Cl_Clipcst' type=dtype('float32') shape=(1,) -- array([1.e-05], dtype=float32)
init: name='Cl_Clipcst1' type=dtype('float32') shape=(1,) -- array([0.99999], dtype=float32)
init: name='Re_Reshapecst' type=dtype('int64') shape=(1,) -- array([-1])
Cast(X1, to=1) -> Ca_output0
  Sub(Su_Subcst, Ca_output0) -> Su_C0
Sigmoid(X2) -> Si_Y0
  Clip(Si_Y0, Cl_Clipcst, Cl_Clipcst1) -> Cl_output0
    Log(Cl_output0) -> Lo_output02
  Mul(Ca_output0, Lo_output02) -> Mu_C02
Identity(Su_Subcst) -> Su_Subcst1
  Sub(Su_Subcst1, Cl_output0) -> Su_C02
    Log(Su_C02) -> Lo_output0
    Mul(Su_C0, Lo_output0) -> Mu_C0
    Add(Mu_C0, Mu_C02) -> Ad_C0
      Neg(Ad_C0) -> Ne_Y0
        ReduceSum(Ne_Y0) -> Re_reduced0
          Reshape(Re_reduced0, Re_Reshapecst) -> Y
  Sub(Cl_output0, Ca_output0) -> Y_grad
output: name='Y' type=dtype('float32') shape=()
output: name='Y_grad' type=dtype('float32') shape=()

benchmark

def loss(x1, x2, eps=1e-5):
    pr = expit(x2)
    cl = numpy.clip(pr, eps, 1 - eps)
    lo = - (1 - x1) * numpy.log(1 - cl) - x1 * numpy.log(cl)
    return lo


fct_numpy = lambda x1, x2: (loss(x1, x2).mean(), expit(x2) - x1)

rows = benchmark(
    'grad_sigmoid_neg_log_loss_error', fct_onx, fct_numpy,
    (numpy.random.randn(1000, 1) > 0).astype(numpy.int64),
    numpy.random.randn(1000, 10).astype(numpy.float32))

all_rows.extend(rows)
piv = pandas.DataFrame(rows).pivot('dim', 'impl', 'average')
piv

Out:

impl   bind_run     numpy       run      sess
dim
1      0.007246  0.032292  0.004600  0.006824
10     0.007825  0.034267  0.005167  0.007362
100    0.011676  0.054584  0.009018  0.011367
200    0.015957  0.075253  0.013220  0.015670
500    0.028750  0.135132  0.025942  0.028400
1000   0.049898  0.239883  0.046879  0.050397
2000   0.050270  0.240150  0.047119  0.049828
10000  0.050023  0.240576  0.046890  0.049994


Graph.

name = rows[0]['name']
ax = piv.plot(logx=True, logy=True)
ax.set_title(name + "\nlower is better")
(figure: grad_sigmoid_neg_log_loss_error, lower is better)

Out:

Text(0.5, 1.0, 'grad_sigmoid_neg_log_loss_error\nlower is better')

Results

df = pandas.DataFrame(all_rows)
df

Out:

      average  deviation  min_exec  max_exec  repeat  number      ttime  context_size                             name      impl    dim
0    0.002472   0.000016  0.002451  0.002533      50     100   0.123616            64                             axpy     numpy      1
1    0.003834   0.000048  0.003807  0.003975      10      50   0.038343            64                             axpy      sess      1
2    0.004020   0.000032  0.004004  0.004114      10      50   0.040200            64                             axpy  bind_run      1
3    0.001767   0.000022  0.001750  0.001827      10      50   0.017672            64                             axpy       run      1
4    0.002464   0.000014  0.002445  0.002531      50     100   0.123204            64                             axpy     numpy     10
..        ...        ...       ...       ...     ...     ...        ...           ...                              ...       ...    ...
315  0.047119   0.000050  0.047072  0.047253      10      50   0.471192            64  grad_sigmoid_neg_log_loss_error       run   2000
316  0.240576   0.000649  0.239843  0.242718      50     100  12.028781            64  grad_sigmoid_neg_log_loss_error     numpy  10000
317  0.049994   0.000101  0.049930  0.050287      10      50   0.499941            64  grad_sigmoid_neg_log_loss_error      sess  10000
318  0.050023   0.000121  0.049916  0.050353      10      50   0.500229            64  grad_sigmoid_neg_log_loss_error  bind_run  10000
319  0.046890   0.000041  0.046851  0.047002      10      50   0.468904            64  grad_sigmoid_neg_log_loss_error       run  10000

[320 rows x 11 columns]



Pivot

piv = pandas.pivot_table(
    df, index=['name', 'impl'], columns='dim', values='average')
piv
print(piv)

Out:

dim                                          1      ...     10000
name                            impl                ...
axpy                            bind_run  0.004020  ...  0.005004
                                numpy     0.002472  ...  0.004765
                                run       0.001767  ...  0.002741
                                sess      0.003834  ...  0.005211
axpyw                           bind_run  0.005410  ...  0.007290
                                numpy     0.008006  ...  0.016235
                                run       0.002002  ...  0.003779
                                sess      0.004804  ...  0.007562
axpyw2                          bind_run  0.005715  ...  0.008592
                                numpy     0.011349  ...  0.023808
                                run       0.002276  ...  0.004910
                                sess      0.005124  ...  0.009274
copy                            bind_run  0.002721  ...  0.002981
                                numpy     0.000649  ...  0.001324
                                run       0.001098  ...  0.001353
                                sess      0.002707  ...  0.003254
grad_loss_absolute_error        bind_run  0.004892  ...  0.009012
                                numpy     0.005378  ...  0.015590
                                run       0.002343  ...  0.006411
                                sess      0.004455  ...  0.008881
grad_loss_elastic_error         bind_run  0.006727  ...  0.014107
                                numpy     0.018878  ...  0.036410
                                run       0.004245  ...  0.011520
                                sess      0.006339  ...  0.014311
grad_loss_square_error          bind_run  0.004992  ...  0.006152
                                numpy     0.006728  ...  0.012709
                                run       0.002403  ...  0.003568
                                sess      0.004590  ...  0.006112
grad_sigmoid_neg_log_loss_error bind_run  0.007246  ...  0.050023
                                numpy     0.032292  ...  0.240576
                                run       0.004600  ...  0.046890
                                sess      0.006824  ...  0.049994
n_penalty_elastic_error         bind_run  0.004585  ...  0.005448
                                numpy     0.011367  ...  0.015799
                                run       0.002621  ...  0.003486
                                sess      0.004443  ...  0.005266
update_penalty_elastic_error    bind_run  0.003592  ...  0.007912
                                numpy     0.006939  ...  0.017042
                                run       0.001929  ...  0.006264
                                sess      0.003570  ...  0.008486

[40 rows x 8 columns]
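
A convenient way to read this table is the speed-up of the fastest onnxruntime path over numpy. The small follow-up below is not part of the original script; it reuses df and averages over all dimensions.

mean_time = pandas.pivot_table(
    df, index='name', columns='impl', values='average')
speedup = mean_time['numpy'] / mean_time['run']
print(speedup.sort_values(ascending=False))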

Graph.

fig, ax = None, None


for i, name in enumerate(sorted(set(df['name']))):
    if fig is None:
        fig, ax = plt.subplots(2, 2, figsize=(8, 12), sharex=True)
    x, y = (i % 4) // 2, (i % 4) % 2
    piv = df[df.name == name].pivot('dim', 'impl', 'average')
    piv.plot(ax=ax[x, y], logx=True, logy=True)
    ax[x, y].set_title(name)
    ax[x, y].xaxis.set_label_text("")
    if i % 4 == 3:
        fig.suptitle("lower is better")
        fig.tight_layout()
        fig, ax = None, None


if fig is not None:
    fig.suptitle("lower is better")
    fig.tight_layout()


# plt.show()
  • (figure, lower is better) axpy, axpyw, axpyw2, copy
  • (figure, lower is better) grad_loss_absolute_error, grad_loss_elastic_error, grad_loss_square_error, grad_sigmoid_neg_log_loss_error
  • (figure, lower is better) n_penalty_elastic_error, update_penalty_elastic_error

Total running time of the script: ( 2 minutes 28.826 seconds)
