Optimize an existing ONNX model
This page answers common “how do I…” questions for optimizing an
existing onnx.ModelProto with the pattern-based optimizer
provided by yobx.xoptim.
The optimizer searches for sequences of nodes matching predefined patterns and rewrites them into equivalent but more efficient ones (constant folding, operator fusion, transpose simplification, …), without changing the model inputs or outputs.
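To make the mechanism concrete, here is a minimal sketch of pattern-based rewriting on a toy graph. The representation (a list of `(op_type, inputs, outputs)` tuples) and every function name are invented for this example; yobx works on `onnx.NodeProto` objects with a much richer API. The rule matches `Add(MatMul(x, w), b)`, replaces it with `Gemm(x, w, b)`, and repeats until nothing matches. It deliberately ignores the shape checks a real pattern must perform.

```python
from typing import List, Optional, Sequence, Tuple

# Toy node: (op_type, inputs, outputs). Invented for this sketch, not a yobx type.
Node = Tuple[str, Tuple[str, ...], Tuple[str, ...]]


def count_uses(nodes: List[Node], name: str) -> int:
    """How many node inputs consume the result ``name``."""
    return sum(inp == name for _, inputs, _ in nodes for inp in inputs)


def match_matmul_add(nodes: List[Node], i: int, outputs: Sequence[str]) -> Optional[int]:
    """If nodes[i] is Add(MatMul(x, w), b), return the index of the MatMul."""
    op, inputs, _ = nodes[i]
    if op != "Add":
        return None
    for j, (op2, _, outs2) in enumerate(nodes):
        if (op2 == "MatMul" and outs2[0] == inputs[0]
                and count_uses(nodes, outs2[0]) == 1   # consumed only by the Add
                and outs2[0] not in outputs):          # not a graph output
            return j
    return None


def optimize(nodes: List[Node], outputs: Sequence[str] = ()) -> List[Node]:
    """Apply the MatMul + Add -> Gemm rewrite until nothing matches."""
    changed = True
    while changed:
        changed = False
        for i in range(len(nodes)):
            j = match_matmul_add(nodes, i, outputs)
            if j is None:
                continue
            x, w = nodes[j][1]
            b = nodes[i][1][1]
            gemm: Node = ("Gemm", (x, w, b), nodes[i][2])
            # The Gemm takes the place of the Add; the MatMul is removed.
            nodes = [gemm if k == i else n for k, n in enumerate(nodes) if k != j]
            changed = True
            break  # restart the scan on the rewritten graph
    return nodes


mlp = [
    ("MatMul", ("x", "W1"), ("t1",)),
    ("Add", ("t1", "b1"), ("linear",)),
    ("Relu", ("linear",), ("relu",)),
    ("MatMul", ("relu", "W2"), ("t2",)),
    ("Add", ("t2", "b2"), ("y",)),
]
print(optimize(mlp, outputs=["y"]))
```

The loop stops only when a full scan finds no match. This is the same fixed-point idea the yobx optimizer applies with many patterns at once.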
How to optimize a model with the default patterns
Load the model into a yobx.xbuilder.GraphBuilder and call
to_onnx with
optimize=True. The default list of patterns is applied automatically.
<<<
from yobx.helpers.onnx_helper import pretty_onnx
from yobx.xbuilder import GraphBuilder
from yobx.doc import demo_mlp_model
onx = demo_mlp_model("temp_doc_optimize_mlp.onnx")
print("--- before optimization ---")
print(pretty_onnx(onx))
gr = GraphBuilder(onx, infer_shapes_options=True)
opt_onx = gr.to_onnx(optimize=True)
print("--- after optimization ---")
print(pretty_onnx(opt_onx))
>>>
--- before optimization ---
opset: domain='' version=18
input: name='x' type=dtype('float32') shape=[3, 10]
init: name='p_layers_0_weight::T10' type=float32 shape=(10, 32)
init: name='p_layers_2_weight::T10' type=float32 shape=(32, 1)
init: name='layers.0.bias' type=float32 shape=(32,)
init: name='layers.2.bias' type=float32 shape=(1,) -- array([-0.142], dtype=float32)
MatMul(x, p_layers_0_weight::T10) -> _onx_matmul_x
Add(_onx_matmul_x, layers.0.bias) -> linear
Relu(linear) -> relu
MatMul(relu, p_layers_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, layers.2.bias) -> output_0
output: name='output_0' type=dtype('float32') shape=[3, 1]
--- after optimization ---
opset: domain='' version=18
input: name='x' type=dtype('float32') shape=[3, 10]
init: name='layers.0.bias' type=float32 shape=(32,) -- GraphBuilder._update_structures_with_proto.1/from(layers.0.bias)
init: name='layers.2.bias' type=float32 shape=(1,) -- array([-0.142], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(layers.2.bias)
init: name='GemmTransposePattern--p_layers_0_weight::T10' type=float32 shape=(32, 10)-- GraphBuilder.constant_folding.from/fold(p_layers_0_weight::T10)##p_layers_0_weight::T10/GraphBuilder._update_structures_with_proto.1/from(p_layers_0_weight::T10)
init: name='GemmTransposePattern--p_layers_2_weight::T10' type=float32 shape=(1, 32)-- GraphBuilder.constant_folding.from/fold(init7_s2_1_32,p_layers_2_weight::T10)##p_layers_2_weight::T10/GraphBuilder._update_structures_with_proto.1/from(p_layers_2_weight::T10)##init7_s2_1_32/TransposeEqualReshapePattern.apply.new_shape
Gemm(x, GemmTransposePattern--p_layers_0_weight::T10, layers.0.bias, transB=1) -> linear
Relu(linear) -> relu
Gemm(relu, GemmTransposePattern--p_layers_2_weight::T10, layers.2.bias, transB=1) -> output_0
output: name='output_0' type=dtype('float32') shape=[3, 1]
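The two MatMul + Add sequences are fused into Gemm nodes by the
default patterns. The weights are now stored transposed and consumed with
transB=1. Constant folding computes the transposition once and stores it in
new initializers (the ones prefixed with GemmTransposePattern--), so it
costs nothing at inference time.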
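Why is this rewrite safe? The folded initializer is the original weight transposed, and Gemm with transB=1 transposes it back, so `Gemm(x, W^T, b, transB=1)` computes exactly `x @ W + b`. The pure-Python sketch below checks this on small made-up matrices. It assumes the ONNX Gemm defaults alpha = beta = 1 and transA = 0, and the helper functions are written only for this illustration.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def add_bias(a: Matrix, bias: List[float]) -> Matrix:
    # ONNX Add broadcasting of a (N,) bias over a (M, N) matrix.
    return [[v + bias[j] for j, v in enumerate(row)] for row in a]


def gemm(a: Matrix, b: Matrix, c: List[float], trans_b: int = 0) -> Matrix:
    # ONNX Gemm with alpha = beta = 1 and transA = 0: Y = A @ op(B) + C.
    return add_bias(matmul(a, transpose(b) if trans_b else b), c)


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # input, shape (3, 2)
w = [[1.0, 0.0, 2.0], [0.5, 1.0, -1.0]]   # MatMul weight, shape (2, 3)
b = [0.1, 0.2, 0.3]                       # bias, shape (3,)
w_t = transpose(w)                        # folded initializer, shape (3, 2)

before = add_bias(matmul(x, w), b)        # Add(MatMul(x, w), b)
after = gemm(x, w_t, b, trans_b=1)        # Gemm(x, w_t, b, transB=1)
print(before == after)
```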
How to choose which patterns to apply
Use yobx.xbuilder.OptimizationOptions to enable or disable
patterns. Patterns can be given either as a single string of
comma-separated names or as a list of pattern instances.
There are a few predefined lists:
default: patterns using only standard ONNX operators.
onnxruntime: patterns specific to onnxruntime; the resulting model may use com.microsoft operators and may then only run with onnxruntime.
default+onnxruntime: both lists combined.
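The naming scheme can be pictured with the hypothetical resolver below. The registry contents and the function `resolve_patterns` are invented for illustration. The real catalogues are much larger, and OptimizationOptions resolves them itself. The sketch only shows how comma-separated names and the `name+name` combination of predefined lists could expand into one ordered list without duplicates.

```python
from typing import Dict, List

# Invented, heavily truncated registry used only for this illustration.
PREDEFINED: Dict[str, List[str]] = {
    "default": ["MatMulAdd", "GemmTranspose", "TransposeEqualReshape"],
    "onnxruntime": ["FusedMatMul", "Gelu"],
}


def resolve_patterns(spec: str) -> List[str]:
    """Expand "a,b" and "list1+list2" into an ordered list without duplicates."""
    resolved: List[str] = []
    for item in spec.split(","):
        for part in item.strip().split("+"):
            part = part.strip()
            if not part:
                continue
            # A predefined list expands to its members; any other name is one pattern.
            for name in PREDEFINED.get(part, [part]):
                if name not in resolved:
                    resolved.append(name)
    return resolved


print(resolve_patterns("MatMulAdd,GemmTranspose"))
print(resolve_patterns("default+onnxruntime"))
```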
<<<
from yobx.helpers.onnx_helper import pretty_onnx
from yobx.xbuilder import GraphBuilder, OptimizationOptions
from yobx.doc import demo_mlp_model
onx = demo_mlp_model("temp_doc_optimize_mlp.onnx")
gr = GraphBuilder(
    onx,
    infer_shapes_options=True,
    optimization_options=OptimizationOptions(
        patterns="MatMulAdd,GemmTranspose", verbose=0
    ),
)
opt_onx = gr.to_onnx(optimize=True)
print(pretty_onnx(opt_onx))
>>>
opset: domain='' version=18
input: name='x' type=dtype('float32') shape=[3, 10]
init: name='layers.0.bias' type=float32 shape=(32,) -- GraphBuilder._update_structures_with_proto.1/from(layers.0.bias)
init: name='layers.2.bias' type=float32 shape=(1,) -- array([-0.142], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(layers.2.bias)
init: name='GemmTransposePattern--p_layers_0_weight::T10' type=float32 shape=(32, 10)-- GraphBuilder.constant_folding.from/fold(p_layers_0_weight::T10)##p_layers_0_weight::T10/GraphBuilder._update_structures_with_proto.1/from(p_layers_0_weight::T10)
init: name='GemmTransposePattern--p_layers_2_weight::T10' type=float32 shape=(1, 32)-- GraphBuilder.constant_folding.from/fold(p_layers_2_weight::T10)##p_layers_2_weight::T10/GraphBuilder._update_structures_with_proto.1/from(p_layers_2_weight::T10)
Gemm(x, GemmTransposePattern--p_layers_0_weight::T10, layers.0.bias, transB=1) -> linear
Relu(linear) -> relu
Gemm(relu, GemmTransposePattern--p_layers_2_weight::T10, layers.2.bias, transB=1) -> output_0
output: name='output_0' type=dtype('float32') shape=[3, 1]
The full list of available patterns is documented at Available Patterns.
How to inspect what the optimizer did
Calling optimize
on the builder returns statistics with one row per optimization step
(pattern applications, constant folding and the overall optimization loop),
including timings and the number of nodes added or removed. The rows can
be aggregated with pandas to get a per-pattern summary.
<<<
import pandas
from yobx.xbuilder import GraphBuilder, OptimizationOptions
from yobx.doc import demo_mlp_model
onx = demo_mlp_model("temp_doc_optimize_mlp.onnx")
gr = GraphBuilder(
    onx,
    infer_shapes_options=True,
    optimization_options=OptimizationOptions(patterns="default"),
)
stat = gr.optimize()
df = pandas.DataFrame(stat)
for c in ["added", "removed"]:
    df[c] = df[c].fillna(0).astype(int)
agg = df.groupby("pattern")[["added", "removed", "time_in"]].sum()
print(agg[(agg["added"] > 0) | (agg["removed"] > 0)])
>>>
added removed time_in
pattern
apply_GemmTransposePattern 4 2 0.000246
apply_MatMulAddPattern 2 4 0.000137
apply_TransposeEqualReshapePattern 1 1 0.000143
constant_folding 0 2 0.000301
optimization 0 2 0.009586
Setting verbose=1 (or higher) on
OptimizationOptions prints
the same information while the optimization runs.
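If pandas is not available, the standard library can compute the same per-pattern summary. The sketch below assumes each statistics row is a dictionary with the keys pattern, added, removed and time_in, as the pandas code above suggests. The rows themselves are made up (including the hypothetical `hypothetical_idle_step`), and missing or None counts are treated as 0, like `fillna(0)`.

```python
from collections import defaultdict
from typing import Any, Dict, List

# Made-up rows with the keys used by the pandas example above.
stat: List[Dict[str, Any]] = [
    {"pattern": "apply_MatMulAddPattern", "added": 1, "removed": 2, "time_in": 0.001},
    {"pattern": "apply_MatMulAddPattern", "added": 1, "removed": 2, "time_in": 0.002},
    {"pattern": "constant_folding", "added": None, "removed": 2, "time_in": 0.003},
    {"pattern": "hypothetical_idle_step", "time_in": 0.004},
]


def summarize(rows: List[Dict[str, Any]]) -> Dict[str, Dict[str, float]]:
    """Sum added/removed/time_in per pattern, keeping entries that changed the graph."""
    totals: Dict[str, Dict[str, float]] = defaultdict(
        lambda: {"added": 0, "removed": 0, "time_in": 0.0}
    )
    for row in rows:
        agg = totals[row["pattern"]]
        agg["added"] += int(row.get("added") or 0)      # None or missing -> 0
        agg["removed"] += int(row.get("removed") or 0)
        agg["time_in"] += float(row.get("time_in") or 0.0)
    return {k: v for k, v in totals.items() if v["added"] or v["removed"]}


for name, agg in sorted(summarize(stat).items()):
    print(name, agg["added"], agg["removed"], round(agg["time_in"], 6))
```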
Comparison with the onnxscript rewriter
The onnxscript package ships its own pattern-based rewriter (Pattern-based Rewrite Using Rules With onnxscript). Both tools serve the same purpose — rewriting an ONNX graph by matching sub-graphs and replacing them — but they differ in API and scope.
A typical onnxscript rewriter looks like this:
import onnx
from onnxscript.rewriter import pattern, rewrite

def matmul_add_pattern(op, x, w, b):
    # Sub-graph to look for: Add(MatMul(x, w), b).
    t = op.MatMul(x, w)
    return op.Add(t, b)

def gemm_replacement(op, x, w, b):
    # Replacement sub-graph. Gemm requires 2-D x and w, and this rule
    # does not check ranks before rewriting.
    return op.Gemm(x, w, b)

rule = pattern.RewriteRule(matmul_add_pattern, gemm_replacement)
onx = onnx.load("model.onnx")
new_onx = rewrite(onx, pattern_rewrite_rules=[rule])
The equivalent with yobx reuses the built-in MatMulAdd pattern
already shipped with the optimizer:
import onnx
from yobx.xbuilder import GraphBuilder, OptimizationOptions
onx = onnx.load("model.onnx")
gr = GraphBuilder(
onx,
infer_shapes_options=True,
optimization_options=OptimizationOptions(patterns="MatMulAdd"),
)
new_onx = gr.to_onnx(optimize=True)
A user-defined rewrite is written as a subclass of
EasyPatternOptimization
(declarative match + apply) or
PatternOptimization (manual
match / apply). The example below rewrites MatMul + Add
into a custom op com.example.FusedGemm, the typical use case
when a fused kernel is provided by a runtime extension. The new node
is created with g.anyop.<OpType>(..., domain=...), which is how
yobx emits non-standard ONNX operators (the same mechanism used
internally by the patterns in yobx.xoptim.patterns_ort to
target com.microsoft):
<<<
from yobx.helpers.onnx_helper import pretty_onnx
from yobx.xbuilder import GraphBuilder, OptimizationOptions
from yobx.xoptim import EasyPatternOptimization
from yobx.doc import demo_mlp_model
class MatMulAddToFusedGemmPattern(EasyPatternOptimization):
    """Fuses ``Add(MatMul(x, w), b)`` into a custom op
    ``com.example.FusedGemm(x, w, b)``."""

    def match_pattern(self, g: "GraphBuilder", x, w, b):
        t = g.op.MatMul(x, w)
        return g.op.Add(t, b)

    def apply_pattern(self, g: "GraphBuilder", x, w, b):
        return g.anyop.FusedGemm(x, w, b, domain="com.example")


onx = demo_mlp_model("temp_doc_optimize_mlp.onnx")
gr = GraphBuilder(
    onx,
    infer_shapes_options=True,
    optimization_options=OptimizationOptions(
        patterns=[MatMulAddToFusedGemmPattern()],
    ),
)
opt_onx = gr.to_onnx(optimize=True)
print(pretty_onnx(opt_onx))
>>>
opset: domain='' version=18
opset: domain='com.example' version=1
input: name='x' type=dtype('float32') shape=[3, 10]
init: name='p_layers_0_weight::T10' type=float32 shape=(10, 32) -- GraphBuilder._update_structures_with_proto.1/from(p_layers_0_weight::T10)
init: name='p_layers_2_weight::T10' type=float32 shape=(32, 1) -- GraphBuilder._update_structures_with_proto.1/from(p_layers_2_weight::T10)
init: name='layers.0.bias' type=float32 shape=(32,) -- GraphBuilder._update_structures_with_proto.1/from(layers.0.bias)
init: name='layers.2.bias' type=float32 shape=(1,) -- array([-0.142], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(layers.2.bias)
FusedGemm[com.example](x, p_layers_0_weight::T10, layers.0.bias) -> linear
Relu(linear) -> relu
FusedGemm[com.example](relu, p_layers_2_weight::T10, layers.2.bias) -> output_0
output: name='output_0' type=dtype('float32') shape=[3, 1]
The resulting model contains com.example.FusedGemm nodes in place
of every MatMul + Add pair, and the corresponding opset import
(com.example) is added automatically by the builder.
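Conceptually, deriving the opset imports only requires collecting the domains used by the nodes. The sketch below is hypothetical: `required_opsets` is not a yobx function, and the default version 1 for custom domains simply mirrors the output above.

```python
from typing import Dict, Iterable


def required_opsets(node_domains: Iterable[str], main_version: int = 18,
                    custom_version: int = 1) -> Dict[str, int]:
    """Map every domain used by a node to an opset version.

    "" is the standard ONNX domain; any other domain gets ``custom_version``.
    """
    opsets: Dict[str, int] = {"": main_version}
    for domain in node_domains:
        opsets.setdefault(domain, custom_version)
    return opsets


# Domains of the three nodes of the rewritten model: FusedGemm, Relu, FusedGemm.
print(required_opsets(["com.example", "", "com.example"]))
```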
The patterns argument of
OptimizationOptions
accepts a list mixing predefined names and user-defined instances —
e.g. patterns=["default", MatMulAddToFusedGemmPattern()] to combine
a custom rewrite with the built-in catalogue.
When the applicability of the fusion depends on shapes, dtypes or
attributes (for example FusedGemm only being valid for 2-D
inputs), subclass PatternOptimization directly and implement match
and apply as plain Python methods. The example below rewrites the
same MatMul + Add into com.example.FusedGemm, but only when
both MatMul operands are rank 2 and the bias is rank 1. Such a guard
is easier to write as plain Python than with the declarative
EasyPatternOptimization API:
<<<
import inspect
from typing import List, Optional
from onnx import NodeProto
from yobx.helpers.onnx_helper import pretty_onnx
from yobx.xbuilder import GraphBuilder, OptimizationOptions
from yobx.xoptim import MatchResult, PatternOptimization
from yobx.doc import demo_mlp_model
class MatMulAddToFusedGemmManualPattern(PatternOptimization):
    """Fuses ``Add(MatMul(x, w), b)`` into ``com.example.FusedGemm``
    when ``x`` and ``w`` are 2-D and ``b`` is 1-D."""

    def match(
        self,
        g: "GraphBuilderPatternOptimization",
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        # The match is rooted at a standard-domain Add.
        if node.op_type != "Add" or node.domain != "":
            return self.none()
        # Its first input must come from a standard MatMul...
        matmul = g.node_before(node.input[0])
        if matmul is None or matmul.op_type != "MatMul" or matmul.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        # ...whose result is not needed anywhere else.
        if g.is_used_more_than_once(matmul.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        x, w = matmul.input
        b = node.input[1]
        # Ranks must be known and match what FusedGemm accepts.
        if not all(g.has_rank(i) for i in (x, w, b)):
            return self.none(node, inspect.currentframe().f_lineno)
        if g.get_rank(x) != 2 or g.get_rank(w) != 2 or g.get_rank(b) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [matmul, node], self.apply)

    def apply(
        self,
        g: "GraphBuilder",
        matmul_node: NodeProto,
        add_node: NodeProto,
    ) -> List[NodeProto]:
        x, w = matmul_node.input
        b = add_node.input[1]
        # One FusedGemm node replaces both matched nodes and keeps the Add output.
        new_node = g.make_node(
            "FusedGemm",
            [x, w, b],
            add_node.output,
            domain="com.example",
            name=f"{self.__class__.__name__}--{add_node.name}",
            doc_string=add_node.doc_string,
        )
        return [new_node]


onx = demo_mlp_model("temp_doc_optimize_mlp_manual.onnx")
gr = GraphBuilder(
    onx,
    infer_shapes_options=True,
    optimization_options=OptimizationOptions(
        patterns=[MatMulAddToFusedGemmManualPattern()],
    ),
)
opt_onx = gr.to_onnx(optimize=True)
print(pretty_onnx(opt_onx))
>>>
opset: domain='' version=18
opset: domain='com.example' version=1
input: name='x' type=dtype('float32') shape=[3, 10]
init: name='p_layers_0_weight::T10' type=float32 shape=(10, 32) -- GraphBuilder._update_structures_with_proto.1/from(p_layers_0_weight::T10)
init: name='p_layers_2_weight::T10' type=float32 shape=(32, 1) -- GraphBuilder._update_structures_with_proto.1/from(p_layers_2_weight::T10)
init: name='layers.0.bias' type=float32 shape=(32,) -- GraphBuilder._update_structures_with_proto.1/from(layers.0.bias)
init: name='layers.2.bias' type=float32 shape=(1,) -- array([-0.142], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(layers.2.bias)
FusedGemm[com.example](x, p_layers_0_weight::T10, layers.0.bias) -> linear
Relu(linear) -> relu
FusedGemm[com.example](relu, p_layers_2_weight::T10, layers.2.bias) -> output_0
output: name='output_0' type=dtype('float32') shape=[3, 1]
The real custom-op fusions in yobx.xoptim.patterns_ort
(com.microsoft.FusedMatMul, com.microsoft.Gelu, …) follow the
same template, with richer guards on shapes, dtypes and attributes.
Main differences:
Out-of-the-box catalogue: both libraries ship a predefined set of rules. yobx exposes its catalogue through patterns="default" (constant folding, transpose simplification, MatMul/Gemm fusions, …), with the full list documented in Available Patterns. onnxscript ships its own collection, accessible through onnxscript.rewriter.rewrite(model) and onnxscript.optimizer.optimize(model). The two catalogues do not cover exactly the same rewrites, so they tend to be complementary rather than interchangeable.
Granularity of the API: onnxscript rewrite rules are expressed as two Python functions that build the match and replacement graphs with onnxscript operators. yobx also supports this style through OnnxEasyPatternOptimization, but it additionally exposes a more imperative API based on PatternOptimization, where match and apply are arbitrary Python methods. This is convenient when the rewriting condition depends on shapes, dtypes or attributes, which is harder to express purely structurally.
Shape and type information: the yobx matcher runs after shape inference and can therefore filter matches on tensor ranks, shapes and dtypes through the g argument of match. The onnxscript rewriter relies on a check callback for similar purposes.
Variable-arity patterns: because match and apply of PatternOptimization are arbitrary Python methods, a single yobx pattern can match a node whose number of inputs (or outputs) is not known when the pattern is written. Typical examples are Concat, Sum, Min/Max, Slice (3 to 5 inputs) or Dropout (with optional ratio/training_mode); see for instance ConcatGatherPattern, SliceSlicePattern and DropoutPattern. The onnxscript rewriter, in contrast, expresses the match graph as a fixed sub-graph, so a rule generally targets one input arity and other arities need separate rules.
Diagnostics: every iteration records statistics (added/removed nodes, time_in per pattern), which makes it easy to spot which patterns actually fire and which ones are expensive. See How to inspect what the optimizer did above.
Integration with conversion: yobx runs the same optimizer automatically at the end of every to_onnx call (for example yobx.torch.to_onnx(), yobx.sklearn.to_onnx(), yobx.sql.to_onnx()). Optimizing a pre-existing ONNX file follows exactly the same code path, starting from an onnx.ModelProto instead of a framework model.
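The variable-arity point can be illustrated with a toy rule written as plain Python. It is not one of the yobx patterns listed above, and the node tuples are invented for the example. The rule splices inner Concat nodes into an outer Concat that uses the same axis, whatever the number of inputs of each node. A fixed-arity template would generally need one rule per input count to do the same.

```python
from collections import Counter
from typing import Dict, List, Optional, Sequence, Tuple

# Toy node: (op_type, inputs, output, axis); axis is None for non-Concat nodes.
Node = Tuple[str, List[str], str, Optional[int]]


def flatten_concat(nodes: List[Node], outputs: Sequence[str] = ()) -> List[Node]:
    """Splice Concat inputs produced by another Concat on the same axis.

    ``nodes`` must be topologically sorted so that inner nodes are rewritten
    before the nodes consuming them.
    """
    uses = Counter(name for _, inputs, _, _ in nodes for name in inputs)
    producers: Dict[str, Node] = {}
    absorbed = set()
    result: List[Node] = []
    for op, inputs, out, axis in nodes:
        if op == "Concat":
            new_inputs: List[str] = []
            for name in inputs:  # any number of inputs
                prod = producers.get(name)
                if (prod is not None and prod[0] == "Concat" and prod[3] == axis
                        and uses[name] == 1 and name not in outputs):
                    new_inputs.extend(prod[1])  # inline the inner Concat
                    absorbed.add(name)
                else:
                    new_inputs.append(name)
            inputs = new_inputs
        node: Node = (op, inputs, out, axis)
        producers[out] = node
        result.append(node)
    # Inner Concat nodes whose output was inlined are no longer needed.
    return [n for n in result if n[2] not in absorbed]


graph = [
    ("Concat", ["a", "b"], "ab", 0),
    ("Concat", ["d", "e", "f"], "def", 0),
    ("Concat", ["ab", "c", "def"], "y", 0),
]
print(flatten_concat(graph, outputs=["y"]))
```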
A note on performance
The matching algorithm is roughly \(O(N \cdot P \cdot I)\) in the
number of nodes N, the number of patterns P and the number of
iterations I (see Pattern Optimizer). Two design
choices keep the constant factor small in practice:
Each pattern can declare its entry node operator type via
fast_op_type. When set, the optimizer indexes the graph once per iteration and only feeds the relevant nodes to the pattern, instead of iterating over the whole graph.
The optimizer is incremental: at each iteration, only nodes that could be affected by previously applied rewritings are revisited, and the loop stops as soon as no pattern fires.
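The effect of fast_op_type on the N·P factor can be estimated with a toy cost model. The function `match_calls` and the pattern list are invented for illustration and are not the yobx implementation. The model counts how many nodes are handed to match functions during one iteration, with and without an index keyed by operator type. Patterns that declare no entry type still scan every node.

```python
from collections import Counter
from typing import List, Optional


def match_calls(node_types: List[str], entry_types: List[Optional[str]],
                use_index: bool) -> int:
    """Number of (pattern, node) pairs examined during one iteration."""
    if not use_index:
        return len(node_types) * len(entry_types)
    index = Counter(node_types)  # built once per iteration: op_type -> count
    calls = 0
    for entry in entry_types:
        # None: the pattern declared no entry type, so it sees every node.
        calls += len(node_types) if entry is None else index[entry]
    return calls


graph = ["MatMul", "Add", "Relu"] * 1000            # N = 3000 nodes
patterns = ["Add", "Transpose", "Reshape", None]    # P = 4 patterns
print(match_calls(graph, patterns, use_index=False))  # 12000
print(match_calls(graph, patterns, use_index=True))   # 4000
```

Multiplying either count by the number of iterations I gives the total work, consistent with the \(O(N \cdot P \cdot I)\) bound; the index only shrinks the constant factor.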
In practice, optimizing a typical transformer-sized model with the
default list of patterns runs in seconds on a single CPU core. The
time_in column produced by optimize (see above)
gives a per-pattern budget that can be used to drop patterns whose
cost exceeds their benefit on a given model — this is exactly the
purpose of the DROPPATTERN environment variable documented in
GraphBuilderPatternOptimization.
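For instance, the per-pattern totals can be screened for patterns that take time without ever rewriting anything. The helper below is hypothetical and the totals are made up. It only produces a list of candidate names; take the exact syntax accepted by DROPPATTERN from the GraphBuilderPatternOptimization documentation.

```python
from typing import Dict, List


def patterns_to_drop(totals: Dict[str, Dict[str, float]], min_time: float) -> List[str]:
    """Names of patterns that never changed the graph but cost at least min_time seconds."""
    return sorted(
        name
        for name, agg in totals.items()
        if agg["added"] == 0 and agg["removed"] == 0 and agg["time_in"] >= min_time
    )


# Made-up totals in the shape of the aggregated table (time_in in seconds).
totals = {
    "PatternA": {"added": 4, "removed": 2, "time_in": 0.050},
    "PatternB": {"added": 0, "removed": 0, "time_in": 0.200},
    "PatternC": {"added": 0, "removed": 0, "time_in": 0.001},
}
print(patterns_to_drop(totals, min_time=0.01))  # ['PatternB']
```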
See also
Pattern Optimizer: design and algorithm of the pattern optimizer.
Available Patterns: full list of patterns shipped with yobx.
Pattern-based Rewrite Using Rules With onnxscript: the onnxscript rewriter tutorial.