Expressions
Symbolic dimension-expression utilities exposed by the
onnx_light.onnx_optim.expressions module. See
Symbolic expression library (onnx_light.onnx_optim.expressions) for the design overview.
Operator tokens
The parser recognises the following operator tokens (listed in ASCII order):
- "%" – modulo
- "&" – encodes min(a, b)
- "(" and ")" – grouping
- "*" – multiplication
- "+" – addition (also unary plus)
- "," – argument separator
- "-" – subtraction (also unary minus)
- "//" – floor (integer) division
- "^" – encodes max(a, b)
The built-in function tokens are CeilToInt, Max and max.
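Because "^" and "&" reuse Python's operator symbols but not Python's meaning, expression strings from this module should not be evaluated with Python's eval. The short illustration below uses only the standard ast module (it does not call onnx_light) to contrast the two readings of the same text:

```python
import ast

# Python itself reads ^ and & as bitwise XOR and AND.
python_xor = eval("2 ^ 9")  # 0b0010 ^ 0b1001 == 0b1011
python_and = eval("2 & 9")  # 0b0010 & 0b1001 == 0b0000

# Reinterpret the same parse trees with this module's meaning: ^ is max, & is min.
xor_node = ast.parse("2 ^ 9", mode="eval").body
and_node = ast.parse("2 & 9", mode="eval").body
assert isinstance(xor_node.op, ast.BitXor) and isinstance(and_node.op, ast.BitAnd)
library_max = max(xor_node.left.value, xor_node.right.value)
library_min = min(and_node.left.value, and_node.right.value)

print(python_xor, library_max)  # 11 9
print(python_and, library_min)  # 0 2
```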
Available simplifications
simplify_expression() applies the following transformers (in order),
twice, followed by a final linear-combination pass:
- CeilToIntTransformer – rewrites CeilToInt(x, n) → (x + n - 1) // n.
- MaxToXorTransformer – rewrites Max(a, b) and max(a, b) → a ^ b.
- SimpleSimplifyTransformer – folds identities such as x ^ x → x, x + 0 → x, x * 1 → x, 0 * x → 0.
- MulDivCancellerTransformer – cancels common symbolic factors in * and // chains (e.g. 2*x//x → 2).
- ExactMulDivConstantFolderTransformer – folds integer constants in * and // chains when the division is exact (e.g. 1024*a//2 → 512*a).
- ReorderCommutativeOpsTransformer – sorts the operands of + and * alphabetically so that equivalent expressions share a canonical form.
- MaxIntTransformer – evaluates int_const ^ int_const to a single integer constant (the larger of the two).
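The arithmetic behind two of these rewrites can be checked in plain Python. The snippet below does not call the library; it assumes, as the operator table states, that "//" is floor division, and ceil_to_int is a hypothetical helper name:

```python
import math
from fractions import Fraction

def ceil_to_int(x: int, n: int) -> int:
    """Integer-only form used by the CeilToInt rewrite: ceil(x / n) for n > 0."""
    return (x + n - 1) // n

# The rewrite agrees with the exact rational ceiling for every sign of x.
for x in range(-20, 21):
    for n in range(1, 8):
        assert ceil_to_int(x, n) == math.ceil(Fraction(x, n))

# Exact constant folding: 1024 * a // 2 == 512 * a for every integer a ...
assert all((1024 * a) // 2 == 512 * a for a in range(-50, 51))
# ... but 3 * a // 2 cannot be folded to (3 // 2) * a, since that fails at a = 3.
assert (3 * 3) // 2 != (3 // 2) * 3
```

This is why the constant folder is restricted to exact divisions: folding an inexact quotient changes the result for some values of the symbol.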
Members
Symbolic dimension expression utilities backed by a C++ library.
This package exposes an AST-based expression engine for simplifying, evaluating, and renaming symbolic shape expressions such as those produced during ONNX model shape inference.
Expressions are strings containing integer constants, symbolic variable names
(e.g. "batch", "seq_length"), and the arithmetic operators +,
-, *, // (floor division), % (modulo), ^ (max), and
& (min).
Typical usage:
from onnx_light.onnx_optim.expressions import (
simplify_expression,
evaluate_expression,
rename_expression,
dim_add,
)
# Simplify a symbolic expression.
result = simplify_expression("2*batch//batch")
# result is 2 (int)
result = simplify_expression("a + b - a")
# result is "b"
# Evaluate with concrete variable assignments.
value = evaluate_expression("x + y", {"x": 3, "y": 5})
# value is 8
# Rename dimension variables.
expr = rename_expression("s0 + seq_len", {"s0": "batch"})
# expr is "batch+seq_len"
# Arithmetic on dimensions (int or str).
d = dim_add("batch", 1)
# d is "1+batch"
The module is exposed as onnx_light.onnx_optim.expressions.
- onnx_light.onnx_optim.expressions.dim_add(a: int | str, b: int | str) → int | str
Adds two dimensions.
Returns a + b as an int when both operands are integers; otherwise builds "(a)+(b)" and simplifies it symbolically.
- Parameters:
a – The first dimension (integer or symbolic string).
b – The second dimension (integer or symbolic string).
- Returns:
The sum as an int when both operands are concrete, or as a simplified str otherwise.
- Return type:
int | str
Examples:
>>> dim_add(3, 4)
7
>>> dim_add("n", 1)
'n+1'
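The binary dim_* helpers (dim_add, dim_sub, dim_mul, dim_div, dim_mod, dim_max, dim_min) share the documented shape: compute exactly when both operands are int, otherwise build a parenthesised string and simplify it. A minimal pure-Python sketch of that dispatch follows; dim_binary_op_sketch is a hypothetical name, and its simplify parameter defaults to the identity so the snippet runs without onnx_light installed:

```python
import operator
from typing import Callable, Union

Dim = Union[int, str]

def dim_binary_op_sketch(
    a: Dim,
    b: Dim,
    symbol: str,
    int_op: Callable[[int, int], int],
    simplify: Callable[[str], Dim] = lambda s: s,
) -> Dim:
    """Hypothetical helper: exact integer fast path, else a parenthesised symbolic string."""
    if isinstance(a, int) and isinstance(b, int):
        return int_op(a, b)
    # Parenthesise both operands so that an operand such as "n+1" keeps its grouping.
    return simplify(f"({a}){symbol}({b})")

print(dim_binary_op_sketch(3, 4, "+", operator.add))    # 7
print(dim_binary_op_sketch("n", 1, "+", operator.add))  # (n)+(1)
```

In the library the symbolic branch is simplified (dim_add("n", 1) gives 'n+1'); whether the real helpers share a common function like this internally is not stated on this page.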
- onnx_light.onnx_optim.expressions.dim_div(a: int | str, b: int | str) → int | str
Floor-divides a by b.
Returns a // b as an int when both operands are integers; otherwise builds "(a)//(b)" and simplifies it symbolically.
- Parameters:
a – The dividend.
b – The divisor.
- Returns:
The floor-division result as an int when both operands are concrete, or as a simplified str otherwise.
- Return type:
int | str
Examples:
>>> dim_div(12, 4)
3
>>> dim_div(7, 2)
3
>>> dim_div("2*n", 2)
'n'
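Floor division differs from truncating integer division (the behaviour of C++'s built-in "/") when the operands have opposite signs and the division is inexact. How the C++ library achieves floor semantics is not documented here; the Python sketch below, with hypothetical helper names, only illustrates the difference and one way to derive floor division from truncation:

```python
def trunc_div(a: int, b: int) -> int:
    """Integer division rounding toward zero, as C++'s built-in '/' does."""
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q

def floor_div(a: int, b: int) -> int:
    """Floor division derived from truncation: step down by one when the
    signs differ and the division leaves a remainder."""
    q = trunc_div(a, b)
    if q * b != a and (a < 0) != (b < 0):
        q -= 1
    return q

# The two roundings disagree only for inexact divisions with opposite signs.
print(trunc_div(-7, 2), floor_div(-7, 2))  # -3 -4

# floor_div matches Python's // on a grid of operands.
assert all(floor_div(a, b) == a // b
           for a in range(-20, 21) for b in range(-6, 7) if b != 0)
```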
- onnx_light.onnx_optim.expressions.dim_max(a: int | str, b: int | str) → int | str
Returns the maximum of two dimensions.
Returns max(a, b) as an int when both operands are integers; otherwise builds "(a)^(b)" (the ^ encoding of max) and simplifies it symbolically.
- Parameters:
a – The first dimension.
b – The second dimension.
- Returns:
The maximum as an int when both operands are concrete, or as a simplified str otherwise.
- Return type:
int | str
Examples:
>>> dim_max(7, 3)
7
>>> dim_max(2, 9)
9
>>> str(dim_max("n", "n"))
'n'
- onnx_light.onnx_optim.expressions.dim_min(a: int | str, b: int | str) → int | str
Returns the minimum of two dimensions.
Returns min(a, b) as an int when both operands are integers; otherwise builds "(a)&(b)" (the & encoding of min) and simplifies it symbolically.
- Parameters:
a – The first dimension.
b – The second dimension.
- Returns:
The minimum as an int when both operands are concrete, or as a simplified str otherwise.
- Return type:
int | str
Examples:
>>> dim_min(2, 9)
2
>>> dim_min(8, 3)
3
- onnx_light.onnx_optim.expressions.dim_mod(a: int | str, b: int | str) → int | str
Computes a modulo b.
Returns a % b as an int when both operands are integers; otherwise builds "(a)%(b)" and simplifies it symbolically.
- Parameters:
a – The dividend.
b – The divisor.
- Returns:
The remainder as an int when both operands are concrete, or as a simplified str otherwise.
- Return type:
int | str
Examples:
>>> dim_mod(10, 3)
1
>>> dim_mod(12, 4)
0
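The examples above use only non-negative operands. If the library's "%" is paired with its floor "//" the way Python's is (an assumption; negative operands are not covered on this page), the following identity holds. The check runs in plain Python and floor_mod_identity_holds is a hypothetical helper name:

```python
def floor_mod_identity_holds(a: int, b: int) -> bool:
    """Checks a == b * (a // b) + a % b and that a non-zero remainder has b's sign."""
    r = a % b
    return a == b * (a // b) + r and (r == 0 or (r > 0) == (b > 0))

assert all(floor_mod_identity_holds(a, b)
           for a in range(-12, 13) for b in (-5, -3, 3, 5))

print(-7 % 3, 7 % -3)  # 2 -2
```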
- onnx_light.onnx_optim.expressions.dim_mul(a: int | str, b: int | str) → int | str
Multiplies two dimensions.
Returns a * b as an int when both operands are integers; otherwise builds "(a)*(b)" and simplifies it symbolically.
- Parameters:
a – The first dimension.
b – The second dimension.
- Returns:
The product as an int when both operands are concrete, or as a simplified str otherwise.
- Return type:
int | str
Examples:
>>> dim_mul(3, 4)
12
>>> dim_mul(0, 5)
0
- onnx_light.onnx_optim.expressions.dim_multi_mul(*args: int | str) → int | str
Multiplies a sequence of dimensions.
Computes the product of all positional arguments. If every argument is an int, the result is an exact integer product; otherwise the expression is built and simplified symbolically. Returns 1 for an empty argument list.
- Parameters:
args – Zero or more dimension values (each an int or str).
- Returns:
The product as an int when all operands are concrete, or as a simplified str otherwise.
- Return type:
int | str
Examples:
>>> dim_multi_mul(2, 3, 4)
24
>>> dim_multi_mul(7)
7
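The documented behaviour, including the empty-product case, can be sketched in plain Python. dim_multi_mul_sketch is a hypothetical name, and its symbolic branch stops at building the string rather than simplifying it as the library does:

```python
import math
from typing import Union

Dim = Union[int, str]

def dim_multi_mul_sketch(*args: Dim) -> Dim:
    """Hypothetical stand-in for dim_multi_mul without the symbolic simplifier."""
    if all(isinstance(arg, int) for arg in args):
        # math.prod returns its start value, 1, for an empty argument list.
        return math.prod(args)
    # Symbolic case: build a parenthesised product that a simplifier could then reduce.
    return "*".join(f"({arg})" for arg in args)

print(dim_multi_mul_sketch(2, 3, 4))    # 24
print(dim_multi_mul_sketch())           # 1
print(dim_multi_mul_sketch("n", 2, 3))  # (n)*(2)*(3)
```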
- onnx_light.onnx_optim.expressions.dim_sub(a: int | str, b: int | str) → int | str
Subtracts b from a.
Returns a - b as an int when both operands are integers; otherwise builds "(a)-(b)" and simplifies it symbolically.
- Parameters:
a – The minuend.
b – The subtrahend.
- Returns:
The difference as an int when both operands are concrete, or as a simplified str otherwise.
- Return type:
int | str
Examples:
>>> dim_sub(10, 3)
7
>>> str(dim_sub("n", "n"))
'0'
- onnx_light.onnx_optim.expressions.evaluate_expression(expression: str, context: dict[str, int]) → int
Evaluates an expression given variable assignments.
Supports signed 64-bit integer constants, variable references resolved via context, the binary operators +, -, *, // (floor division), % (modulo), ^ (max) and & (min), unary -, and the built-in CeilToInt(n, div) function (ceiling division).
- Parameters:
expression – The expression string to evaluate.
context – A mapping from variable name to its integer value.
- Returns:
The integer result of evaluating the expression.
- Return type:
int
- Raises:
RuntimeError – If the expression has a syntax error, references an unknown variable, or contains an unsupported construct.
Examples:
>>> evaluate_expression("x - y", {"x": 5, "y": 6})
-1
>>> evaluate_expression("-x", {"x": 5})
-5
>>> evaluate_expression("CeilToInt(7, 2)", {})
4
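The supported constructs map naturally onto a tree walk. The following is a pure-Python sketch (Python 3.9+) built on the standard ast module, not the library's C++ evaluator; evaluate_sketch is a hypothetical name, and it assumes Python's operator precedence ("&" binds tighter than "^", both looser than "+" and "-"), which this page does not specify for the C++ parser:

```python
import ast

_BINOPS = {
    ast.Add: lambda a, b: a + b,
    ast.Sub: lambda a, b: a - b,
    ast.Mult: lambda a, b: a * b,
    ast.FloorDiv: lambda a, b: a // b,
    ast.Mod: lambda a, b: a % b,
    ast.BitXor: max,  # '^' encodes max(a, b)
    ast.BitAnd: min,  # '&' encodes min(a, b)
}

def evaluate_sketch(expression: str, context: dict[str, int]) -> int:
    try:
        tree = ast.parse(expression, mode="eval")
    except SyntaxError as exc:
        raise RuntimeError(f"syntax error in {expression!r}") from exc

    def visit(node: ast.AST) -> int:
        if isinstance(node, ast.Constant) and type(node.value) is int:
            return node.value
        if isinstance(node, ast.Name):
            if node.id not in context:
                raise RuntimeError(f"unknown variable {node.id!r}")
            return context[node.id]
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -visit(node.operand)
        if isinstance(node, ast.BinOp) and type(node.op) in _BINOPS:
            return _BINOPS[type(node.op)](visit(node.left), visit(node.right))
        if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
                and node.func.id == "CeilToInt" and len(node.args) == 2
                and not node.keywords):
            n, div = visit(node.args[0]), visit(node.args[1])
            return (n + div - 1) // div  # ceiling division for div > 0
        raise RuntimeError(f"unsupported construct: {ast.unparse(node)}")

    return visit(tree.body)

print(evaluate_sketch("x - y", {"x": 5, "y": 6}))  # -1
print(evaluate_sketch("a ^ b", {"a": 2, "b": 9}))  # 9
```

Unlike the library, this sketch uses Python's arbitrary-precision integers rather than signed 64-bit values.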
- onnx_light.onnx_optim.expressions.parse_expression_tokens(expr: str) → set[str]
Returns the set of variable names referenced in expr.
Parses expr and collects every variable name (Name AST node). If the expression has a syntax error, the function returns {expr} (a set containing the original string), matching the reference behaviour.
- Parameters:
expr – The expression string to scan.
- Returns:
An unordered set of variable-name strings. Contains only expr itself when parsing fails.
- Return type:
set[str]
Examples:
>>> sorted(parse_expression_tokens("a + b * c"))
['a', 'b', 'c']
>>> parse_expression_tokens("a +")
{'a +'}
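The same behaviour can be approximated with Python's ast module. In this sketch (variable_names_sketch is a hypothetical name) names used as call targets, such as CeilToInt, are skipped; whether the library reports or skips function names is not stated on this page, so that choice is an assumption:

```python
import ast

def variable_names_sketch(expr: str) -> set[str]:
    """Collects ast.Name identifiers, skipping names used as call targets."""
    try:
        tree = ast.parse(expr, mode="eval")
    except SyntaxError:
        # Mirror the documented fallback: return the original string in a set.
        return {expr}
    # AST nodes hash by identity, so this set holds the exact function-name nodes.
    call_targets = {node.func for node in ast.walk(tree) if isinstance(node, ast.Call)}
    return {
        node.id
        for node in ast.walk(tree)
        if isinstance(node, ast.Name) and node not in call_targets
    }

print(sorted(variable_names_sketch("a + b * c")))  # ['a', 'b', 'c']
print(variable_names_sketch("a +"))                # {'a +'}
```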
- onnx_light.onnx_optim.expressions.rename_dynamic_expression(expression: str, replacements: dict[str, str]) → str
Renames variables in expression and simplifies the result.
Applies the following pipeline in order:
1. Parse expression.
2. Rewrite Max(a, b) → a ^ b.
3. Apply the replacements mapping.
4. Apply a lightweight simplification pass.
5. Unparse and strip spaces.
Returns expression unchanged if it has a syntax error.
- Parameters:
expression – The expression string to transform.
replacements – A mapping from old variable name to new variable name.
- Returns:
The renamed and simplified expression string (no spaces), or expression unchanged on parse failure.
- Return type:
str
Examples:
>>> rename_dynamic_expression("s9+seq_length",
...     {"s9": "cache_length", "seq_length": "seq_length"})
'cache_length+seq_length'
>>> rename_dynamic_expression("a +", {"a": "b"})
'a +'
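Steps 1, 2, 3 and 5 of this pipeline can be sketched with ast.NodeTransformer (Python 3.9+ for ast.unparse). rename_sketch and the transformer classes are hypothetical names, and step 4, the library's lightweight simplification pass, is deliberately left out:

```python
import ast

class _MaxToXor(ast.NodeTransformer):
    """Step 2: rewrite Max(a, b) calls into the a ^ b form."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        if isinstance(node.func, ast.Name) and node.func.id == "Max" and len(node.args) == 2:
            return ast.BinOp(left=node.args[0], op=ast.BitXor(), right=node.args[1])
        return node

class _RenameNames(ast.NodeTransformer):
    """Step 3: replace variable names according to a mapping."""

    def __init__(self, replacements: dict[str, str]) -> None:
        self.replacements = replacements

    def visit_Name(self, node: ast.Name) -> ast.Name:
        return ast.Name(id=self.replacements.get(node.id, node.id), ctx=node.ctx)

def rename_sketch(expression: str, replacements: dict[str, str]) -> str:
    try:
        tree = ast.parse(expression, mode="eval")  # step 1
    except SyntaxError:
        return expression  # documented fallback: return the input unchanged
    tree = _MaxToXor().visit(tree)
    tree = _RenameNames(replacements).visit(tree)
    # Step 4 (the library's lightweight simplification pass) is omitted here.
    return ast.unparse(tree.body).replace(" ", "")  # step 5

print(rename_sketch("Max(s10, s3)", {"s10": "E", "s3": "D"}))  # E^D
```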
- onnx_light.onnx_optim.expressions.rename_expression(expr: str, mapping: dict[str, str]) → str
Renames variables in expr according to mapping.
Also converts Max(a, b) calls to the a^b form before renaming so that composite keys such as "E^D" can be matched. The result has all spaces removed.
- Parameters:
expr – The expression string to rename.
mapping – A mapping from old variable name to new variable name.
- Returns:
The renamed expression string (no spaces).
- Return type:
str
- Raises:
RuntimeError – If expr cannot be parsed.
Examples:
>>> rename_expression("s52+seq_length", {"s52": "B"})
'B+seq_length'
>>> rename_expression("Max(s10, s3)", {"s10": "E", "s3": "D"})
'E^D'
- onnx_light.onnx_optim.expressions.simplify_expression(expr: str | int) → str | int
Simplifies a symbolic or numeric expression.
Applies a pipeline of AST transformations:
- CeilToInt(x, n) is expanded to (x + n - 1) // n.
- Identity folds: x^x → x, x + 0 → x, x * 1 → x.
- Common symbolic factors in * and // chains are cancelled (e.g. 2*x//x → 2).
- Integer constants in * and // chains are folded when the division is exact (e.g. 1024*a//2 → 512*a).
- Max(a, b) and max(a, b) are rewritten to a^b.
- Commutative + and * chains are sorted alphabetically.
- int_const ^ int_const is evaluated as max.
The pipeline is applied twice to allow multi-step cancellations. A final linear-combination visitor normalises the remaining sum of terms.
- Parameters:
expr – An expression string (e.g. "2*batch//batch") or an integer, which is treated as already simplified.
- Returns:
An int when the expression reduces to a numeric constant, or a simplified str otherwise. Returns expr unchanged when it contains syntax that the parser does not recognise.
- Return type:
str | int
Examples:
>>> simplify_expression("2*batch//batch")
2
>>> simplify_expression("a + b - a")
'b'
>>> simplify_expression("5 + x - 2 + 3")
'x+6'
>>> simplify_expression("CeilToInt(b+c, 2)")
'(1+b+c)//2'
>>> simplify_expression("1024*a//2")
'512*a'
>>> simplify_expression("b + a")
'a+b'
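The reordering step can be sketched with Python's ast module. The hypothetical sort_commutative_sketch below flattens each + or * chain and sorts its operands by their source text; that sort key is an assumption, since this page does not pin down the library's exact ordering rule (for instance, where integer constants land once the linear-combination pass has run):

```python
import ast

def _flatten(node: ast.AST, op_type: type) -> list[ast.AST]:
    """Collects the operands of a nested chain of a single operator."""
    if isinstance(node, ast.BinOp) and isinstance(node.op, op_type):
        return _flatten(node.left, op_type) + _flatten(node.right, op_type)
    return [node]

class _SortCommutative(ast.NodeTransformer):
    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        self.generic_visit(node)  # canonicalise sub-expressions first
        for op_type in (ast.Add, ast.Mult):
            if isinstance(node.op, op_type):
                operands = sorted(_flatten(node, op_type), key=ast.unparse)
                # Rebuild a left-nested chain from the sorted operands.
                result = operands[0]
                for operand in operands[1:]:
                    result = ast.BinOp(left=result, op=op_type(), right=operand)
                return result
        return node

def sort_commutative_sketch(expr: str) -> str:
    tree = _SortCommutative().visit(ast.parse(expr, mode="eval"))
    return ast.unparse(tree.body).replace(" ", "")

print(sort_commutative_sketch("b + a"))    # a+b
print(sort_commutative_sketch("c*b + a"))  # a+b*c
```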
- onnx_light.onnx_optim.expressions.simplify_two_expressions(expr1: str, expr2: str) → dict[str, int]
Returns the non-zero coefficient map of expr1 - expr2.
Builds the expression expr1 - (expr2), runs a linear-combination visitor, and returns only the variable coefficients that are non-zero. An empty dict indicates that the two expressions are equal under linear arithmetic.
- Parameters:
expr1 – The first expression string.
expr2 – The second expression string.
- Returns:
A dict mapping each variable name (or sub-expression key) to its integer coefficient in expr1 - expr2. Zero-coefficient terms are omitted.
- Return type:
dict[str, int]
Examples:
>>> simplify_two_expressions("s52+seq_length", "s52+s70")
{'s70': -1, 'seq_length': 1}
>>> simplify_two_expressions("e*2", "e+e")
{}
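A linear-combination visitor of this kind can be sketched in pure Python (3.9+). In this sketch, linear_terms and coefficient_diff_sketch are hypothetical names, the constant term is stored under an assumed key "1" (the page does not say how the library reports constants), and any non-linear sub-expression becomes an opaque term keyed by its source text:

```python
import ast

CONST_KEY = "1"  # assumed key for the constant term; not specified by the library docs

def _combine(left: dict[str, int], right: dict[str, int], sign: int) -> dict[str, int]:
    """Returns left + sign * right, keeping zero entries for later filtering."""
    out = dict(left)
    for key, coef in right.items():
        out[key] = out.get(key, 0) + sign * coef
    return out

def linear_terms(node: ast.AST) -> dict[str, int]:
    """Maps each term of a linear expression to its integer coefficient."""
    if isinstance(node, ast.Constant) and type(node.value) is int:
        return {CONST_KEY: node.value}
    if isinstance(node, ast.Name):
        return {node.id: 1}
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
        sign = -1 if isinstance(node.op, ast.USub) else 1
        return {k: sign * v for k, v in linear_terms(node.operand).items()}
    if isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub)):
        sign = 1 if isinstance(node.op, ast.Add) else -1
        return _combine(linear_terms(node.left), linear_terms(node.right), sign)
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
        left, right = linear_terms(node.left), linear_terms(node.right)
        # A product stays linear only when one factor is a pure constant.
        for const_side, other in ((left, right), (right, left)):
            if set(const_side) <= {CONST_KEY}:
                factor = const_side.get(CONST_KEY, 0)
                return {k: factor * v for k, v in other.items()}
    # Any other sub-expression becomes an opaque term keyed by its source text.
    return {ast.unparse(node).replace(" ", ""): 1}

def coefficient_diff_sketch(expr1: str, expr2: str) -> dict[str, int]:
    """Non-zero coefficients of expr1 - (expr2)."""
    tree = ast.parse(f"{expr1} - ({expr2})", mode="eval")
    return {k: v for k, v in linear_terms(tree.body).items() if v != 0}

print(coefficient_diff_sketch("e*2", "e+e"))  # {}
```

Because opaque terms are keyed by raw text, this sketch treats "x*y" and "y*x" as different terms; a canonicalising pass like the commutative reordering in simplify_expression would be needed to merge them.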