Compares implementations of Einsum#

This example compares different equations for the function numpy.einsum. It compares the numpy implementation to a custom implementation, the onnxruntime implementation and the opt-einsum optimisation. If available, tensorflow and pytorch are included as well. The custom implementation does not do any transpose. It uses parallelisation and SIMD optimisation when the summation happens on the last axis of both matrices, and it only implements matrix multiplication. We also measure the improvement made with the function einsum.
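
To see what the custom kernel actually computes, note that the first benchmarked equation, bsnh,btnh->bnts, reduces to a batched matrix multiplication once axis n is moved next to the batch axis. The following sketch, with arbitrary shapes, only illustrates that equivalence:

import numpy

x = numpy.random.rand(2, 8, 12, 64).astype(numpy.float32)   # bsnh
y = numpy.random.rand(2, 16, 12, 64).astype(numpy.float32)  # btnh

# move n next to the batch axis, then sum over h with a plain matmul
xt = x.transpose((0, 2, 1, 3))                               # bnsh
yt = y.transpose((0, 2, 1, 3))                               # bnth
got = yt @ xt.swapaxes(-2, -1)                               # bnts
expected = numpy.einsum("bsnh,btnh->bnts", x, y)
assert numpy.allclose(expected, got, atol=1e-4)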

Available optimisation#

The code below shows which optimisation is used for the custom implementation, AVX or SSE, and the number of available processors, which is also the default number of threads used for parallelisation.

import numpy
import pandas
import matplotlib.pyplot as plt
from onnxruntime import InferenceSession
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx.algebra.onnx_ops import OnnxEinsum
from cpyquickhelper.numbers import measure_time
from tqdm import tqdm
from opt_einsum import contract
from mlprodict.testing.experimental_c_impl.experimental_c import (
    custom_einsum_float, code_optimisation)
from mlprodict.testing.einsum.einsum_fct import _einsum
print(code_optimisation())

Out:

AVX-omp=8

Einsum: common code#

try:
    from tensorflow import einsum as tf_einsum, convert_to_tensor
except ImportError:
    tf_einsum = None
try:
    from torch import einsum as torch_einsum, from_numpy
except ImportError:
    torch_einsum = None


def build_ort_einsum(equation, op_version=14):  # opset=13, 14, ...
    node = OnnxEinsum('x', 'y', equation=equation,
                      op_version=op_version,
                      output_names=['z'])
    onx = node.to_onnx(inputs=[('x', FloatTensorType()),
                               ('y', FloatTensorType())],
                       target_opset=op_version)
    # explicit providers avoids an error with recent onnxruntime builds
    sess = InferenceSession(onx.SerializeToString(),
                            providers=["CPUExecutionProvider"])
    return lambda x, y: sess.run(None, {'x': x, 'y': y})


def build_ort_decomposed(equation, op_version=14):  # opset=13, 14, ...
    cache = _einsum(equation, numpy.float32, opset=op_version,
                    optimize=True, verbose=True, runtime="python")
    if not hasattr(cache, 'onnx_'):
        cache.build()
    sess = InferenceSession(cache.onnx_.SerializeToString(),
                            providers=["CPUExecutionProvider"])
    return lambda x, y: sess.run(None, {'X0': x, 'X1': y})


def loop_einsum_eq(fct, equation, xs, ys):
    for x, y in zip(xs, ys):
        fct(equation, x, y)


def loop_einsum_eq_th(fct, equation, xs, ys):
    for x, y in zip(xs, ys):
        fct(equation, x, y, nthread=-1)


def loop_einsum(fct, xs, ys):
    for x, y in zip(xs, ys):
        fct(x, y)


def custom_einsum_float_tr(eq, x, y):
    if eq == "bshn,bthn->bnts":
        x = x.transpose((0, 1, 3, 2))
        y = y.transpose((0, 1, 3, 2))
        return custom_einsum_float("bsnh,btnh->bnts", x, y, nthread=-1)
    if eq == "bhsn,bhtn->bnts":
        x = x.transpose((0, 2, 3, 1))
        y = y.transpose((0, 2, 3, 1))
        return custom_einsum_float("bsnh,btnh->bnts", x, y, nthread=-1)
    return custom_einsum_float(eq, x, y, nthread=-1)


def benchmark_equation(equation):
    # equations
    ort_einsum = build_ort_einsum(equation)
    ort_einsum_decomposed = build_ort_decomposed(equation)
    res = []
    for dim in tqdm([8, 16, 32, 64, 100, 128, 200,
                     256, 500, 512]):
        xs = [numpy.random.rand(2, dim, 12, 64).astype(numpy.float32)
              for _ in range(5)]
        ys = [numpy.random.rand(2, dim, 12, 64).astype(numpy.float32)
              for _ in range(5)]

        # numpy
        ctx = dict(equation=equation, xs=xs, ys=ys, einsum=numpy.einsum,
                   loop_einsum=loop_einsum, loop_einsum_eq=loop_einsum_eq,
                   loop_einsum_eq_th=loop_einsum_eq_th)
        obs = measure_time(
            "loop_einsum_eq(einsum, equation, xs, ys)",
            div_by_number=True, context=ctx, repeat=5, number=1)
        obs['dim'] = dim
        obs['fct'] = 'numpy.einsum'
        res.append(obs)

        # opt-einsum
        ctx['einsum'] = contract
        obs = measure_time(
            "loop_einsum_eq(einsum, equation, xs, ys)",
            div_by_number=True, context=ctx, repeat=5, number=1)
        obs['dim'] = dim
        obs['fct'] = 'opt-einsum'
        res.append(obs)

        # onnxruntime
        ctx['einsum'] = ort_einsum
        obs = measure_time(
            "loop_einsum(einsum, xs, ys)",
            div_by_number=True, context=ctx, repeat=5, number=1)
        obs['dim'] = dim
        obs['fct'] = 'ort_einsum'
        res.append(obs)

        # onnxruntime decomposed
        ctx['einsum'] = ort_einsum_decomposed
        obs = measure_time(
            "loop_einsum(einsum, xs, ys)",
            div_by_number=True, context=ctx, repeat=5, number=1)
        obs['dim'] = dim
        obs['fct'] = 'ort_dec'
        res.append(obs)

        # custom implementation
        ctx['einsum'] = custom_einsum_float
        obs = measure_time(
            "loop_einsum_eq_th(einsum, equation, xs, ys)",
            div_by_number=True, context=ctx, repeat=5, number=1)
        obs['dim'] = dim
        obs['fct'] = 'c_einsum'
        res.append(obs)

        # transpose + custom implementation
        ctx['einsum'] = custom_einsum_float_tr
        obs = measure_time(
            "loop_einsum_eq(einsum, equation, xs, ys)",
            div_by_number=True, context=ctx, repeat=5, number=1)
        obs['dim'] = dim
        obs['fct'] = 'c_einsum_tr'
        res.append(obs)

        if tf_einsum is not None:
            # tensorflow
            ctx['einsum'] = tf_einsum
            ctx['xs'] = [convert_to_tensor(x) for x in xs]
            ctx['ys'] = [convert_to_tensor(y) for y in ys]
            obs = measure_time(
                "loop_einsum_eq(einsum, equation, xs, ys)",
                div_by_number=True, context=ctx, repeat=5, number=1)
            obs['dim'] = dim
            obs['fct'] = 'tf_einsum'
            res.append(obs)

        if torch_einsum is not None:
            # torch
            ctx['einsum'] = torch_einsum
            ctx['xs'] = [from_numpy(x) for x in xs]
            ctx['ys'] = [from_numpy(y) for y in ys]
            obs = measure_time(
                "loop_einsum_eq(einsum, equation, xs, ys)",
                div_by_number=True, context=ctx, repeat=5, number=1)
            obs['dim'] = dim
            obs['fct'] = 'torch_einsum'
            res.append(obs)

    # Dataframes
    df = pandas.DataFrame(res)
    piv = df.pivot(index='dim', columns='fct', values='average')

    rs = piv.copy()
    rs['c_einsum'] = rs['numpy.einsum'] / rs['c_einsum']
    rs['ort_einsum'] = rs['numpy.einsum'] / rs['ort_einsum']
    rs['ort_dec'] = rs['numpy.einsum'] / rs['ort_dec']
    rs['opt-einsum'] = rs['numpy.einsum'] / rs['opt-einsum']
    if 'c_einsum_tr' in rs.columns:
        rs['c_einsum_tr'] = rs['numpy.einsum'] / rs['c_einsum_tr']
    if 'tf_einsum' in rs.columns:
        rs['tf_einsum'] = rs['numpy.einsum'] / rs['tf_einsum']
    if 'torch_einsum' in rs.columns:
        rs['torch_einsum'] = rs['numpy.einsum'] / rs['torch_einsum']
    rs['numpy.einsum'] = 1.

    # Graphs.
    fig, ax = plt.subplots(1, 2, figsize=(14, 5))
    piv.plot(logx=True, logy=True, ax=ax[0],
             title="Einsum benchmark\n%s -- (2, N, 12, 64)"
                   " lower better" % equation)
    ax[0].legend(prop={"size": 9})
    rs.plot(logx=True, logy=True, ax=ax[1],
            title="Einsum Speedup, baseline=numpy\n%s -- (2, N, 12, 64)"
                  " higher better" % equation)
    ax[1].plot([min(rs.index), max(rs.index)], [0.5, 0.5], 'g--')
    ax[1].plot([min(rs.index), max(rs.index)], [2., 2.], 'g--')
    ax[1].legend(prop={"size": 9})

    return df, rs, ax

First equation: bsnh,btnh->bnts#

The decomposition of this equation without the einsum function gives the following.

[Decomposition graph: input 0 (bsnh) -> expand_dims(axes=(4,)) -> transpose(perm=(0, 2, 1, 4, 3)); input 1 (btnh) -> expand_dims(axes=(3,)) -> transpose(perm=(0, 2, 3, 1, 4)); batch_dot(batch_axes=(0, 1), sum_axes=(4,)) -> transpose(perm=(0, 4, 1, 3, 2)) -> squeeze(axes=(1,))]
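
The graph can be read as a chain of expand_dims, transpose, batch_dot and squeeze. The sketch below is a simplified numpy rendering of that chain, not the exact mlprodict operator, checked against numpy.einsum on arbitrary shapes:

import numpy


def decomposed_bsnh_btnh_bnts(x, y):
    # expand_dims + transpose as in the graph above
    xt = numpy.transpose(numpy.expand_dims(x, 4), (0, 2, 1, 4, 3))  # (b, n, s, 1, h)
    yt = numpy.transpose(numpy.expand_dims(y, 3), (0, 2, 3, 1, 4))  # (b, n, 1, t, h)
    b, n, s = xt.shape[0], xt.shape[1], xt.shape[2]
    t, h = yt.shape[3], yt.shape[4]
    # batch_dot: batch axes (0, 1), summation over the last axis
    dot = xt.reshape(b * n, s, h) @ yt.reshape(b * n, t, h).transpose((0, 2, 1))
    dot = dot.reshape(b, n, s, t, 1)
    # final transpose + squeeze produce bnts
    return numpy.squeeze(numpy.transpose(dot, (0, 4, 1, 3, 2)), axis=1)


x = numpy.random.rand(2, 8, 12, 64).astype(numpy.float32)
y = numpy.random.rand(2, 16, 12, 64).astype(numpy.float32)
assert numpy.allclose(decomposed_bsnh_btnh_bnts(x, y),
                      numpy.einsum("bsnh,btnh->bnts", x, y), atol=1e-4)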
dfs = []
equation = "bsnh,btnh->bnts"
df, piv, ax = benchmark_equation(equation)
df.pivot(index="fct", columns="dim", values="average")
dfs.append(df)
[Figure: left, Einsum benchmark bsnh,btnh->bnts -- (2, N, 12, 64), lower is better; right, Einsum Speedup, baseline=numpy, bsnh,btnh->bnts -- (2, N, 12, 64), higher is better]

Out:

  0%|          | 0/121 [00:00<?, ?it/s]
0.032 rtbest='bsnh,btnh->bnts':   0%|          | 0/121 [00:00<?, ?it/s]
[...]
0.032 rtbest='ntbs,nhbs->nbht': 100%|##########| 121/121 [00:07<00:00, 16.38it/s]

  0%|          | 0/10 [00:00<?, ?it/s]
[...]
100%|##########| 10/10 [02:18<00:00, 13.83s/it]

Second equation: bshn,bthn->bnts#

The summation does not happen on the last axis but on the previous one. Is it worth transposing before doing the summation? The decomposition of this equation without the einsum function gives the following.

[Decomposition graph: input 0 (bshn) -> expand_dims(axes=(4,)) -> transpose(perm=(0, 3, 1, 4, 2)); input 1 (bthn) -> expand_dims(axes=(3,)) -> transpose(perm=(0, 4, 3, 1, 2)); batch_dot(batch_axes=(0, 1), sum_axes=(4,)) -> transpose(perm=(0, 4, 1, 3, 2)) -> squeeze(axes=(1,))]
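
This is the rewriting used by custom_einsum_float_tr above: transpose both inputs so that h becomes the last axis, then reuse the bsnh,btnh->bnts kernel. A quick numpy check of that rewriting, with arbitrary shapes and for illustration only:

import numpy

x = numpy.random.rand(2, 8, 64, 12).astype(numpy.float32)   # bshn
y = numpy.random.rand(2, 16, 64, 12).astype(numpy.float32)  # bthn
direct = numpy.einsum("bshn,bthn->bnts", x, y)
moved = numpy.einsum("bsnh,btnh->bnts",
                     x.transpose((0, 1, 3, 2)),              # bshn -> bsnh
                     y.transpose((0, 1, 3, 2)))              # bthn -> btnh
assert numpy.allclose(direct, moved, atol=1e-4)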
equation = "bshn,bthn->bnts"
df, piv, ax = benchmark_equation(equation)
df.pivot(index="fct", columns="dim", values="average")
dfs.append(df)
[Figure: left, Einsum benchmark bshn,bthn->bnts -- (2, N, 12, 64), lower is better; right, Einsum Speedup, baseline=numpy, bshn,bthn->bnts -- (2, N, 12, 64), higher is better]

Out:

  0%|          | 0/121 [00:00<?, ?it/s]
0.031 rtbest='bshn,bthn->bnts':   0%|          | 0/121 [00:00<?, ?it/s]
[...]
0.031 rtbest='bnsh,btsh->bhtn': 100%|##########| 121/121 [00:07<00:00, 16.76it/s]

  0%|          | 0/10 [00:00<?, ?it/s]
[...]
100%|##########| 10/10 [05:19<00:00, 31.92s/it]

Third equation: bhsn,bhtn->bnts#

The summation does not happen on the last axis but on the second one. It is worth transposing before multiplying. The decomposition of this equation without the einsum function gives the following.

[Decomposition graph: input 0 (bhsn) -> expand_dims(axes=(4,)) -> transpose(perm=(0, 3, 2, 4, 1)); input 1 (bhtn) -> expand_dims(axes=(3,)) -> transpose(perm=(0, 4, 3, 2, 1)); batch_dot(batch_axes=(0, 1), sum_axes=(4,)) -> transpose(perm=(0, 4, 1, 3, 2)) -> squeeze(axes=(1,))]
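
The same rewriting applies here with a different permutation, the one implemented in custom_einsum_float_tr. A quick numpy check, again with arbitrary shapes and for illustration only:

import numpy

x = numpy.random.rand(2, 64, 8, 12).astype(numpy.float32)   # bhsn
y = numpy.random.rand(2, 64, 16, 12).astype(numpy.float32)  # bhtn
direct = numpy.einsum("bhsn,bhtn->bnts", x, y)
moved = numpy.einsum("bsnh,btnh->bnts",
                     x.transpose((0, 2, 3, 1)),              # bhsn -> bsnh
                     y.transpose((0, 2, 3, 1)))              # bhtn -> btnh
assert numpy.allclose(direct, moved, atol=1e-4)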
equation = "bhsn,bhtn->bnts"
df, piv, ax = benchmark_equation(equation)
df.pivot(index="fct", columns="dim", values="average")
dfs.append(df)
[Figure: left, Einsum benchmark bhsn,bhtn->bnts -- (2, N, 12, 64), lower is better; right, Einsum Speedup, baseline=numpy, bhsn,bhtn->bnts -- (2, N, 12, 64), higher is better]

Out:

  0%|          | 0/121 [00:00<?, ?it/s]
0.032 rtbest='bhsn,bhtn->bnts':   0%|          | 0/121 [00:00<?, ?it/s]
[...]
0.032 rtbest='hbns,hbts->hstn': 100%|##########| 121/121 [00:07<00:00, 16.39it/s]

  0%|          | 0/10 [00:00<?, ?it/s]
[...]
100%|##########| 10/10 [00:37<00:00,  3.77s/it]

Conclusion#

pytorch seems quite efficient on these examples. The custom implementation was a way to investigate how einsum is implemented and to find ways to optimise it.

merged = pandas.concat(dfs)
name = "einsum"
merged.to_csv("plot_%s.csv" % name, index=False)
merged.to_excel("plot_%s.xlsx" % name, index=False)
plt.savefig("plot_%s.png" % name)

plt.show()

Total running time of the script: ( 8 minutes 46.672 seconds)
