Symbolic expression library (onnx_light.onnx_optim.expressions)#
This page describes the design of the symbolic dimension-expression library
introduced in onnx_light/onnx_optim/ and exposed as the
Python module onnx_light.onnx_optim.expressions.
The library was ported from yobx/xexpressions and re-implemented in C++ for speed and to avoid a runtime Python dependency in the shape-inference path.
Motivation#
ONNX shape inference and model transformation frequently need to manipulate
symbolic dimension expressions — strings such as "2*batch//batch",
"CeilToInt(seq_len, 8)", or "cache_length + seq_length" — that encode
relationships between dynamic tensor dimensions.
A pure-string approach (regex substitution, eval) is fragile. An
AST-based approach allows systematic:
- Simplification: 2*batch//batch → 2, a + b - a → b.
- Evaluation: substitute concrete integer values and compute the result.
- Renaming: replace internal names (s0, s1, …) with user-visible names (batch, seq_length, …).
- Arithmetic: add, subtract, multiply, divide, and compare symbolic dimension values without losing the symbolic form when the result is still symbolic.
Expression grammar#
The parser accepts a subset of Python arithmetic expressions:
expr ::= xor_expr ('^' xor_expr)*
xor_expr ::= and_expr ('&' and_expr)*
and_expr ::= add_expr (('+' | '-') add_expr)*
add_expr ::= mul_expr (('*' | '//' | '%') mul_expr)*
mul_expr ::= unary
unary ::= ('+' | '-') unary | atom
atom ::= INTEGER | NAME | '(' expr ')' | NAME '(' arg_list ')'
arg_list ::= expr (',' expr)*
Operator precedence (low → high):
| Operator | Meaning |
|---|---|
| ^ | Encodes max |
| & | Encodes min |
| + | Addition |
| - | Subtraction |
| * | Multiplication |
| // | Floor (integer) division |
| % | Modulo |
| Unary +, - | Unary plus and negation |
^ and & borrow Python’s bitwise-xor and bitwise-and syntax and
re-interpret them as max and min respectively, following the
yobx/xexpressions convention. This lets the simplifier represent max/min
without adding new operator tokens.
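The grammar and precedence rules above can be sketched as a small recursive-descent parser. This is an illustrative pure-Python sketch, not the library's C++ implementation: the names tokenize, parse and evaluate, and the tuple encoding of the tree, are assumptions made for this example only.

```python
import operator
import re

# '//' is listed first so that it is matched as a single token.
TOKEN = re.compile(r"\s*(//|\d+|[A-Za-z_]\w*|[-+*%^&(),])")
# Binary operator levels, lowest precedence first, following the grammar.
LEVELS = [("^",), ("&",), ("+", "-"), ("*", "//", "%")]
BINARY = {"^": max, "&": min, "+": operator.add, "-": operator.sub,
          "*": operator.mul, "//": operator.floordiv, "%": operator.mod}


def tokenize(text):
    tokens, pos, text = [], 0, text.strip()
    while pos < len(text):
        match = TOKEN.match(text, pos)
        if match is None:
            raise ValueError(f"cannot tokenize at position {pos}")
        tokens.append(match.group(1))
        pos = match.end()
    return tokens


def parse(text):
    """Return a nested tuple: (op, left, right), ('u-', x) or ('call', name, args)."""
    tokens, pos = tokenize(text), 0

    def peek():
        return tokens[pos] if pos < len(tokens) else None

    def take(expected=None):
        nonlocal pos
        token = peek()
        if token is None or (expected is not None and token != expected):
            raise ValueError(f"expected {expected or 'a token'}, got {token!r}")
        pos += 1
        return token

    def binary(level):
        if level == len(LEVELS):
            return unary()
        node = binary(level + 1)
        while peek() in LEVELS[level]:  # left-associative chain
            node = (take(), node, binary(level + 1))
        return node

    def unary():
        if peek() in ("+", "-"):
            return ("u" + take(), unary())
        return atom()

    def atom():
        token = take()
        if token.isdigit():
            return int(token)
        if token == "(":
            node = binary(0)
            take(")")
            return node
        if not (token[0].isalpha() or token[0] == "_"):
            raise ValueError(f"unexpected token {token!r}")
        if peek() != "(":
            return token  # plain symbolic name
        take("(")
        args = [binary(0)]
        while peek() == ",":
            take(",")
            args.append(binary(0))
        take(")")
        return ("call", token, tuple(args))

    tree = binary(0)
    if peek() is not None:
        raise ValueError(f"trailing token {peek()!r}")
    return tree


def evaluate(node, env):
    """Evaluate a parsed tree with concrete integer values for every name."""
    if isinstance(node, int):
        return node
    if isinstance(node, str):
        return env[node]
    if node[0] == "call":
        args = [evaluate(arg, env) for arg in node[2]]
        if node[1] == "CeilToInt":
            return (args[0] + args[1] - 1) // args[1]
        if node[1] in ("Max", "max"):
            return max(args)
        raise ValueError(f"unknown function {node[1]}")
    if node[0] in ("u+", "u-"):
        value = evaluate(node[1], env)
        return -value if node[0] == "u-" else value
    return BINARY[node[0]](evaluate(node[1], env), evaluate(node[2], env))
```

With this sketch, parse("a ^ b & c + 1") groups as a ^ (b & (c + 1)), so ^ (max) binds loosest and the arithmetic operators bind tightest, matching the table above.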
The function call syntax Name(arg, ...) is used for two built-in functions:
- CeilToInt(n, div): ceiling division; the simplifier expands it to (n + div - 1) // div before other passes run.
- Max(a, b) / max(a, b): rewritten to a^b before evaluation.
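The two rewrites can be pictured as a tree-to-tree pass. The sketch below is hypothetical: it assumes a tuple encoding where names are str, literals are int, a binary node is (op, left, right), a unary node is ("u-", operand) and a call is ("call", name, args). The function name expand_calls is invented for this example.

```python
def expand_calls(node):
    """Return a new tree with CeilToInt and Max/max calls rewritten."""
    if not isinstance(node, tuple):
        return node  # leaf: int literal or symbolic name
    if node[0] == "call":
        _, name, args = node
        args = tuple(expand_calls(arg) for arg in args)
        if name == "CeilToInt" and len(args) == 2:
            n, div = args
            # (n + div - 1) // div, with '+' and '-' grouped left to right
            return ("//", ("-", ("+", n, div), 1), div)
        if name in ("Max", "max") and len(args) == 2:
            return ("^", args[0], args[1])
        return ("call", name, args)
    # unary or binary node: keep the operator, rewrite the children
    return (node[0],) + tuple(expand_calls(child) for child in node[1:])


tree = ("+", ("call", "Max", ("a", "b")), ("call", "CeilToInt", ("seq_len", 8)))
rewritten = expand_calls(tree)
```

The arithmetic identity behind the first rewrite holds for positive divisors: for n = 17 and div = 8, (17 + 8 - 1) // 8 is 3, which is the ceiling of 17 / 8.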
AST node types#
The parsed expression is represented as a tree of Node sub-classes:
| Type | Description |
|---|---|
| … | Leaf: a signed 64-bit integer literal. |
| … | Leaf: a symbolic variable reference (e.g. batch). |
| … | Interior: a binary operation on a left and a right sub-tree. |
| … | Interior: a unary + or - applied to one sub-tree. |
| … | Interior: a named function call with a list of argument sub-trees. |
All nodes are heap-allocated and owned through NodePtr
(std::unique_ptr<Node>). Every node provides a virtual clone()
method for deep copying.
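The shape of such a node hierarchy can be sketched in Python for illustration. The class names below (IntNode, NameNode, BinaryNode, UnaryNode, CallNode) are hypothetical and are not taken from expressions.h; the real nodes are C++ classes owned through std::unique_ptr, which this sketch only approximates with an explicit clone() method.

```python
from dataclasses import dataclass, field
from typing import List


class Node:
    """Base class; every concrete node knows how to deep-copy itself."""

    def clone(self) -> "Node":
        raise NotImplementedError


@dataclass
class IntNode(Node):  # leaf: integer literal
    value: int

    def clone(self) -> "Node":
        return IntNode(self.value)


@dataclass
class NameNode(Node):  # leaf: symbolic variable
    name: str

    def clone(self) -> "Node":
        return NameNode(self.name)


@dataclass
class BinaryNode(Node):  # interior: left op right
    op: str
    left: Node
    right: Node

    def clone(self) -> "Node":
        return BinaryNode(self.op, self.left.clone(), self.right.clone())


@dataclass
class UnaryNode(Node):  # interior: unary + or -
    op: str
    operand: Node

    def clone(self) -> "Node":
        return UnaryNode(self.op, self.operand.clone())


@dataclass
class CallNode(Node):  # interior: Name(arg, ...)
    name: str
    args: List[Node] = field(default_factory=list)

    def clone(self) -> "Node":
        return CallNode(self.name, [arg.clone() for arg in self.args])
```

Because clone() recurses into every child, a copy shares no node with the original, which mirrors the single-owner discipline that unique_ptr enforces in C++.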
Simplification pipeline#
simplify_expression() applies a fixed
sequence of AST transformers, then runs the sequence a second time to allow
multi-step cancellations to converge.
Each transformer is a pure tree-to-tree rewrite that produces a new
NodePtr without mutating the original:
| Transformer | What it does |
|---|---|
| … | Rewrites … |
| … | Rewrites … |
| … | Folds identities: … |
| … | Collects all factors in a … |
| … | Folds integer constants in … |
| … | Sorts operands of … |
| … | Evaluates … |
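Why a fixed sequence is run twice can be shown with a toy sketch. The two rules below (fold_identities and fold_constants) and the tuple encoding (op, left, right) are hypothetical stand-ins chosen for this example; they are not the library's transformers.

```python
import operator

OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul}


def rewrite(node, rule):
    """Apply rule bottom-up and return a new tree; tuples are never mutated."""
    if isinstance(node, tuple):
        node = (node[0],) + tuple(rewrite(child, rule) for child in node[1:])
    return rule(node)


def fold_identities(node):
    """x*1 -> x, 1*x -> x, x+0 -> x, 0+x -> x."""
    if isinstance(node, tuple):
        op, left, right = node
        if (op, right) in (("*", 1), ("+", 0)):
            return left
        if (op, left) in (("*", 1), ("+", 0)):
            return right
    return node


def fold_constants(node):
    """Evaluate a binary node whose operands are both integers."""
    if isinstance(node, tuple):
        op, left, right = node
        if isinstance(left, int) and isinstance(right, int) and op in OPS:
            return OPS[op](left, right)
    return node


PIPELINE = [fold_identities, fold_constants]  # fixed order


def run_pipeline(node, passes=2):
    for _ in range(passes):
        for rule in PIPELINE:
            node = rewrite(node, rule)
    return node


expr = ("*", "x", ("-", 3, 2))        # x * (3 - 2)
after_one = run_pipeline(expr, passes=1)
after_two = run_pipeline(expr, passes=2)
```

In this toy ordering the first pass can only fold 3 - 2 into 1, because the identity rule ran before the constant was available; the second pass then removes the multiplication by 1.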
After two passes of this pipeline, a final
ExpressionSimplifierAddVisitor walks the tree and collects a linear
combination {symbol → coefficient}. This lets a + b - a simplify
to b and 3*x + 2*x simplify to 5*x even across multiple
transformer passes.
If the linear combination reduces to a pure integer constant (no remaining
symbolic terms), the result is returned as an int64_t; otherwise the
normalised sum is unparsed back to a string.
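The linear-combination step can be sketched as follows. This is a simplified illustration, not ExpressionSimplifierAddVisitor itself: the names collect and linear_simplify, the tuple encoding (op, left, right), the restriction to integer-times-symbol products, and the order of terms in the output string are all assumptions of this example.

```python
from collections import defaultdict


def collect(node, sign, coeffs):
    """Accumulate sign * node into coeffs; the key None holds the constant."""
    if isinstance(node, int):
        coeffs[None] += sign * node
    elif isinstance(node, str):
        coeffs[node] += sign
    elif node[0] in ("+", "-"):
        collect(node[1], sign, coeffs)
        collect(node[2], sign if node[0] == "+" else -sign, coeffs)
    elif node[0] == "*" and isinstance(node[1], int) and isinstance(node[2], str):
        coeffs[node[2]] += sign * node[1]
    else:
        raise ValueError(f"not a linear term: {node!r}")


def linear_simplify(node):
    """Return an int when no symbol survives, otherwise a normalised string."""
    coeffs = defaultdict(int)
    collect(node, 1, coeffs)
    constant = coeffs.pop(None, 0)
    terms = {sym: c for sym, c in sorted(coeffs.items()) if c != 0}
    if not terms:
        return constant
    out = ""
    for sym, c in terms.items():
        body = sym if abs(c) == 1 else f"{abs(c)}*{sym}"
        out += ("-" if c < 0 else "+" if out else "") + body
    if constant:
        out += f"{constant:+d}"
    return out


cancelled = linear_simplify(("-", ("+", "a", "b"), "a"))         # a + b - a
merged = linear_simplify(("+", ("*", 3, "x"), ("*", 2, "x")))    # 3*x + 2*x
constant = linear_simplify(("-", ("+", "n", 4), ("+", "n", 1)))  # (n+4) - (n+1)
```

Here cancelled is the string "b", merged is "5*x", and constant is the integer 3, illustrating the int-or-string return convention described above.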
Unparser#
simplify_expression() and all other functions that produce an expression
string use unparse() to convert an AST back to a string. The unparser
inserts parentheses only where the precedence rules above require them,
so the output round-trips through parse() to an equivalent AST.
For example:
from onnx_light.onnx_optim.expressions import simplify_expression
simplify_expression("(a + b) * c") # "(a+b)*c" — parens kept (needed)
simplify_expression("a * b + c") # "a*b+c" — no parens (not needed)
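A precedence-driven unparser can be sketched in a few lines. This is an assumption-laden illustration rather than the library's unparse(): it handles only binary nodes encoded as (op, left, right) tuples, and the name unparse_sketch is invented for this example.

```python
# Higher number = binds tighter, following the precedence table.
PREC = {"^": 1, "&": 2, "+": 3, "-": 3, "*": 4, "//": 4, "%": 4}


def unparse_sketch(node, parent_prec=0, right_side=False):
    """Render a tree, adding parentheses only when re-parsing would regroup it."""
    if isinstance(node, int):
        return str(node)
    if isinstance(node, str):
        return node
    op, left, right = node
    prec = PREC[op]
    text = unparse_sketch(left, prec) + op + unparse_sketch(right, prec, True)
    # Operators are left-associative, so a right operand with the same
    # precedence as its parent must keep its parentheses.
    if prec < parent_prec or (prec == parent_prec and right_side):
        text = "(" + text + ")"
    return text


kept = unparse_sketch(("*", ("+", "a", "b"), "c"))
dropped = unparse_sketch(("+", ("*", "a", "b"), "c"))
nested = unparse_sketch(("-", "a", ("-", "b", "c")))
```

In this sketch kept is "(a+b)*c", dropped is "a*b+c", and nested is "a-(b-c)", since dropping the parentheses in the last case would change the grouping on re-parse.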
Dimension operations#
The DimType type represents a tensor
dimension as either a concrete int or a symbolic str. The dimension
operation functions (dim_add(),
dim_sub(),
dim_mul(),
dim_div(),
dim_mod(),
dim_max(),
dim_min(),
dim_multi_mul()) share a common pattern:
- If both operands are integers, compute the result exactly and return an integer.
- Otherwise, build an expression string "(a) op (b)", call simplify_expression(), and return the result (still an integer if the simplifier reduces it fully, otherwise a string).
This ensures that symbolic arithmetic never accumulates unnormalised intermediate expressions:
from onnx_light.onnx_optim.expressions import dim_add, dim_mul, dim_div
dim_add("batch", 1) # "1+batch"
dim_mul(2, "seq_length") # "2*seq_length"
dim_div("2*seq_length", 2) # "seq_length" (simplified)
dim_div("2*n", "n") # 2 (int — fully reduced)
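The shared two-branch pattern can be sketched without the library. Everything named here (make_dim_op, identity_simplifier, sketch_dim_add and friends) is hypothetical; in particular the placeholder simplifier returns its input unchanged, so the symbolic results below are not normalised the way the real functions normalise them.

```python
import operator
from typing import Callable, Union

DimType = Union[int, str]


def identity_simplifier(expression: str) -> DimType:
    # Placeholder for a real simplifier: returns the string untouched.
    return expression


def make_dim_op(symbol: str, int_op: Callable[[int, int], int],
                simplify: Callable[[str], DimType] = identity_simplifier):
    """Build a dimension operation following the two-branch pattern."""

    def dim_op(a: DimType, b: DimType) -> DimType:
        if isinstance(a, int) and isinstance(b, int):
            return int_op(a, b)  # exact integer fast path
        # Symbolic path: wrap both operands so precedence cannot leak.
        return simplify(f"({a}) {symbol} ({b})")

    return dim_op


sketch_dim_add = make_dim_op("+", operator.add)
sketch_dim_mul = make_dim_op("*", operator.mul)
sketch_dim_max = make_dim_op("^", max)

exact = sketch_dim_add(2, 3)
symbolic = sketch_dim_add("batch", 1)
```

Wrapping each operand in parentheses before simplification means an operand such as "a+b" cannot be regrouped by a higher-precedence operator like *.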
Renaming#
Two renaming functions and one higher-level helper cover different use cases:
- rename_expression(): Renames variable names according to a mapping, also converting Max(a, b) to a^b beforehand. Raises RuntimeError on parse failure. Intended for deterministic, one-shot renames where a parse error is truly unexpected.
- rename_dynamic_expression(): Like rename_expression, but also applies a lightweight simplification pass and silently returns the original string on parse failure. Intended for best-effort renaming during shape inference where the expression may occasionally be a raw ONNX node name rather than a real expression.
- rename_dynamic_dimensions(): A higher-level helper. Given a set of equivalence classes (dimension names that are known to be equal to each other) and a set of user-visible preferred names, it produces a mapping from all internal names to their canonical user-visible equivalents. Names starting with a configurable ban prefix (default "DYN") are never selected as canonical targets.
Build layout#
The expressions library is compiled as part of the onnx_optim CMake
STATIC target (lib_onnx_optim). lib_onnx_optim depends on
lib_onnx_op and is linked into the Python extension so that the
expressions code is available to all callers that consume the optimisation
library.
The C++ header and implementation files live in:
onnx_light/onnx_optim/
├── expressions.h    ← public API (AST types + all free functions)
└── expressions.cc   ← full implementation (tokenizer, parser,
                       transformers, evaluator, unparser)
The Python module onnx_light.onnx_optim.expressions wraps the C++ functions
exposed via the _onnxpy.expressions nanobind submodule (defined in
onnx_light/onnx_py/_onnxpy_submodules.cc).
Python wrapper:
onnx_light/onnx_optim/expressions.py ← documented Python wrappers
API reference#
C++ API: expressions.h
Python API: Expressions