expressions.h
Symbolic dimension-expression utilities for ONNX shape inference.
Provides a lightweight AST-based library for parsing, simplifying, evaluating, and renaming symbolic shape expressions such as those produced during ONNX shape inference (e.g. "2*batch//2" → "batch").
See Expressions for the Python interface.
-
namespace onnx_light
-
namespace onnx_optim
-
namespace expressions
Typedefs
-
using NodePtr = std::unique_ptr<Node>
Owning pointer to an AST node.
All AST construction and transformation functions return NodePtr values; the tree owns its children and is freed when the root NodePtr goes out of scope.
-
using SimplifyResult = std::variant<int64_t, std::string>
Return type of simplify_expression.
Holds either an int64_t when the expression reduces to a pure numeric constant, or a std::string when symbolic variables remain after simplification.
Use std::holds_alternative<int64_t>(r) to check which case applies, or call simplify_result_to_string() for a uniform string representation.
SimplifyResult r = simplify_expression("2*batch//batch");
assert(std::holds_alternative<int64_t>(r));
assert(std::get<int64_t>(r) == 2);
SimplifyResult s = simplify_expression("a + b");
assert(std::holds_alternative<std::string>(s));
assert(std::get<std::string>(s) == "a+b");
-
using DimType = std::variant<int64_t, std::string>
A dimension value: either a concrete integer or a symbolic string.
int64_t is used when the dimension is statically known; std::string when it is symbolic (e.g. "batch" or "seq_length+1").
DimType d1 = int64_t{64};       // concrete dimension
DimType d2 = std::string{"N"};  // symbolic dimension
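Both SimplifyResult and DimType are two-alternative variants, so callers usually branch on the held type. The following is a minimal, self-contained sketch of that pattern using std::visit. The alias Dim and the helper describe_dim are hypothetical names introduced for illustration only; they are not part of expressions.h.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>
#include <variant>

// Hypothetical stand-in with the same shape as DimType.
using Dim = std::variant<int64_t, std::string>;

// Returns "static:<n>" for a concrete dimension and "symbolic:<name>" otherwise.
inline std::string describe_dim(const Dim &d) {
  return std::visit(
      [](const auto &v) -> std::string {
        using T = std::decay_t<decltype(v)>;
        if constexpr (std::is_same_v<T, int64_t>) {
          return "static:" + std::to_string(v);  // integer alternative
        } else {
          return "symbolic:" + v;                // string alternative
        }
      },
      d);
}
```

With this sketch, describe_dim(Dim{int64_t{64}}) yields "static:64" and describe_dim(Dim{std::string{"N"}}) yields "symbolic:N".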
Enums
-
enum class BinOpKind
Binary operator kind used in the expression AST.
The ^ and & operators are borrowed from Python’s bitwise-xor / bitwise-and syntax and are re-interpreted as max and min in this expression system, matching the convention from yobx/xexpressions.
Values:
-
enumerator Add
Addition: a + b.
-
enumerator Sub
Subtraction: a - b.
-
enumerator Mult
Multiplication: a * b.
-
enumerator FloorDiv
Floor (integer) division: a // b.
-
enumerator Mod
Modulo: a % b.
-
enumerator BitXor
Encodes max(a, b) using the ^ syntax.
-
enumerator BitAnd
Encodes min(a, b) using the & syntax.
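To make the operator meanings concrete, here is a small sketch of how the seven kinds could be applied to two integers. The enum Op and the helpers floor_div, floor_mod, and apply_op are hypothetical; in particular, Python-style rounding toward negative infinity for negative operands is an assumption, since this page only states that // is floor division.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical mirror of the enumerators listed above.
enum class Op { Add, Sub, Mult, FloorDiv, Mod, BitXor, BitAnd };

// Floor division: rounds toward negative infinity, unlike C++'s '/'.
inline int64_t floor_div(int64_t a, int64_t b) {
  if (b == 0) throw std::runtime_error("division by zero");
  int64_t q = a / b;
  if ((a % b != 0) && ((a < 0) != (b < 0))) --q;
  return q;
}

// Modulo consistent with floor_div: a == b * floor_div(a, b) + floor_mod(a, b).
inline int64_t floor_mod(int64_t a, int64_t b) {
  return a - b * floor_div(a, b);
}

// Applies one binary operator kind to two integers.
inline int64_t apply_op(Op op, int64_t a, int64_t b) {
  switch (op) {
    case Op::Add: return a + b;
    case Op::Sub: return a - b;
    case Op::Mult: return a * b;
    case Op::FloorDiv: return floor_div(a, b);
    case Op::Mod: return floor_mod(a, b);
    case Op::BitXor: return std::max(a, b);  // '^' means max
    case Op::BitAnd: return std::min(a, b);  // '&' means min
  }
  throw std::logic_error("unknown operator");
}
```

The explicit floor_div helper matters only when an operand is negative; for the non-negative values typical of shape dimensions, plain C++ integer division gives the same result.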
Functions
-
NodePtr parse(const std::string &expr)
Parses expr into an AST.
The grammar follows Python operator precedence for the supported subset of operators: +, -, *, //, %, ^ (max), & (min), unary +/-, parentheses, and function calls with comma-separated arguments.
Operator precedence (low to high):
^ (BitXor / max)
& (BitAnd / min)
+, -
*, //, %
unary -, +
atoms (constants, names, parenthesised sub-expressions, calls)
auto tree = parse("2*batch//batch");
// tree is BinOp(BinOp(Constant(2), Mult, Name("batch")), FloorDiv, Name("batch"))
- Parameters:
expr – The expression string to parse.
- Throws:
std::runtime_error – if the input contains a lexical or syntax error.
- Returns:
An owning pointer to the root of the parsed AST.
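The precedence table can be illustrated with a toy recursive-descent evaluator that has one method per level. This is a sketch under stated assumptions, not the library's parser: it handles integer literals only (no names or calls), evaluates directly instead of building an AST, and assumes Python-style floor division for negative operands. ToyParser and toy_eval are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>

// Toy recursive-descent evaluator over integer literals only.
// One method per precedence level, lowest level first.
class ToyParser {
 public:
  explicit ToyParser(std::string text) : s_(std::move(text)) {}

  int64_t run() {
    const int64_t v = max_level();
    skip_ws();
    if (pos_ != s_.size()) throw std::runtime_error("unexpected trailing input");
    return v;
  }

 private:
  std::string s_;
  std::size_t pos_ = 0;

  void skip_ws() {
    while (pos_ < s_.size() && std::isspace(static_cast<unsigned char>(s_[pos_]))) ++pos_;
  }
  // Consumes `tok` if it is the next token in the input.
  bool eat(const std::string &tok) {
    skip_ws();
    if (s_.compare(pos_, tok.size(), tok) != 0) return false;
    pos_ += tok.size();
    return true;
  }
  static int64_t floor_div(int64_t a, int64_t b) {
    if (b == 0) throw std::runtime_error("division by zero");
    const int64_t q = a / b;
    return (a % b != 0 && ((a < 0) != (b < 0))) ? q - 1 : q;
  }
  int64_t max_level() {  // '^' (max), lowest precedence
    int64_t v = min_level();
    while (eat("^")) v = std::max(v, min_level());
    return v;
  }
  int64_t min_level() {  // '&' (min)
    int64_t v = additive();
    while (eat("&")) v = std::min(v, additive());
    return v;
  }
  int64_t additive() {  // '+', '-'
    int64_t v = term();
    for (;;) {
      if (eat("+")) v += term();
      else if (eat("-")) v -= term();
      else return v;
    }
  }
  int64_t term() {  // '*', '//', '%'
    int64_t v = unary();
    for (;;) {
      if (eat("*")) {
        v *= unary();
      } else if (eat("//")) {
        v = floor_div(v, unary());
      } else if (eat("%")) {
        const int64_t d = unary();
        v -= d * floor_div(v, d);
      } else {
        return v;
      }
    }
  }
  int64_t unary() {  // unary '-', '+'
    if (eat("-")) return -unary();
    if (eat("+")) return unary();
    return atom();
  }
  int64_t atom() {  // integer literal or parenthesised sub-expression
    if (eat("(")) {
      const int64_t v = max_level();
      if (!eat(")")) throw std::runtime_error("expected ')'");
      return v;
    }
    skip_ws();
    const std::size_t start = pos_;
    while (pos_ < s_.size() && std::isdigit(static_cast<unsigned char>(s_[pos_]))) ++pos_;
    if (start == pos_) throw std::runtime_error("expected a number");
    return std::stoll(s_.substr(start, pos_ - start));
  }
};

inline int64_t toy_eval(const std::string &expr) { return ToyParser(expr).run(); }
```

Because ^ binds loosest, toy_eval("2^1+4") is max(2, 1+4) = 5 rather than max(2, 1)+4 = 6, and toy_eval("3&2+4") is min(3, 6) = 3.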
-
std::string unparse(const Node &node)
Converts node back to a canonical string expression.
Follows Python’s ast.unparse parenthesisation rules: operator precedence determines where parentheses are inserted so that the output round-trips through parse() to an equivalent AST.
auto tree = parse("(a + b) * c");
std::string s = unparse(*tree); // s == "(a+b)*c"
- Parameters:
node – The root AST node to convert.
- Returns:
A string representation of the expression without extra spaces.
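The precedence-driven parenthesisation described above can be sketched on a miniature AST. Everything here (Expr, leaf, binop, precedence, to_text) is hypothetical and much simpler than the real Node hierarchy; the sketch assumes all binary operators are left-associative, so a right-hand child of equal precedence is parenthesised.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical miniature AST: a leaf (name or number) or a binary operation.
struct Expr {
  std::string leaf;                // used when lhs/rhs are null
  std::string op;                  // "+", "-", "*", "//", "%", "^", "&"
  std::unique_ptr<Expr> lhs, rhs;
};

inline std::unique_ptr<Expr> leaf(std::string text) {
  auto e = std::make_unique<Expr>();
  e->leaf = std::move(text);
  return e;
}

inline std::unique_ptr<Expr> binop(std::unique_ptr<Expr> l, std::string op,
                                   std::unique_ptr<Expr> r) {
  auto e = std::make_unique<Expr>();
  e->op = std::move(op);
  e->lhs = std::move(l);
  e->rhs = std::move(r);
  return e;
}

// Binding strength, following the precedence table documented for parse().
inline int precedence(const Expr &e) {
  if (!e.lhs) return 5;  // leaves bind tightest
  if (e.op == "^") return 1;
  if (e.op == "&") return 2;
  if (e.op == "+" || e.op == "-") return 3;
  return 4;  // "*", "//", "%"
}

// Prints without spaces; parenthesises a child only when dropping the
// parentheses would change how the string parses.
inline std::string to_text(const Expr &e) {
  if (!e.lhs) return e.leaf;
  const int p = precedence(e);
  std::string l = to_text(*e.lhs);
  std::string r = to_text(*e.rhs);
  if (precedence(*e.lhs) < p) l = "(" + l + ")";
  if (precedence(*e.rhs) <= p) r = "(" + r + ")";
  return l + e.op + r;
}
```

For example, the tree for (a + b) * c prints as "(a+b)*c", while a * b + c needs no parentheses and prints as "a*b+c".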
-
std::string simplify_result_to_string(const SimplifyResult &r)
Returns a string representation of a SimplifyResult.
Converts an int64_t result to its decimal string; returns the std::string variant unchanged.
- Parameters:
r – The result to convert.
- Returns:
Decimal string for integer results; the simplified expression string otherwise.
-
SimplifyResult simplify_expression(const std::string &expr)
Simplifies a symbolic or numeric expression string.
Applies a pipeline of AST transformations twice:
CeilToIntTransformer: expands CeilToInt(x, n) to (x + n - 1) // n.
SimpleSimplifyTransformer: folds x^x → x, x + 0 → x, x * 1 → x.
MulDivCancellerTransformer: cancels common symbolic factors, e.g. 2*x//x → 2.
ExactMulDivConstantFolderTransformer: folds 1024*a//2 → 512*a.
MaxToXorTransformer: rewrites Max(a,b) and max(a,b) to a^b.
ReorderCommutativeOpsTransformer: sorts + / * operands alphabetically.
MaxIntTransformer: folds a ^ b to max(a, b) when both operands are integer constants.
A final linear-combination visitor then collects the result as a normalised sum of symbolic terms plus an integer constant.
// Fully numeric result:
auto r1 = simplify_expression("2*batch//batch");
assert(std::get<int64_t>(r1) == 2);
// Symbolic result:
auto r2 = simplify_expression("a + b - a");
assert(std::get<std::string>(r2) == "b");
// CeilToInt expansion:
auto r3 = simplify_expression("CeilToInt(b+c, 2)");
assert(std::get<std::string>(r3) == "(1+b+c)//2");
- Parameters:
expr – The expression string to simplify.
- Returns:
An int64_t when the result is fully numeric, or a simplified std::string otherwise. Returns expr unchanged when it contains syntax that the parser does not recognise (e.g. "::" in ONNX node names).
-
SimplifyResult simplify_expression(int64_t value)
Returns the integer as-is (convenience overload for uniform call sites).
- Parameters:
value – An integer that is already fully simplified.
- Returns:
A SimplifyResult holding value.
-
std::map<std::string, int64_t> simplify_two_expressions(const std::string &expr1, const std::string &expr2)
Returns the non-zero coefficient map of the difference expr1 - (expr2).
Builds the combined expression expr1 - (expr2), runs the linear-combination visitor, and returns only those variable coefficients that are non-zero. An empty map indicates that the two expressions are equal under linear arithmetic.
auto diff = simplify_two_expressions("s52+seq_length", "s52+s70");
// diff == {{"s70", -1}, {"seq_length", 1}}
auto same = simplify_two_expressions("e*2", "e+e");
// same is empty: the two expressions are equal
- Parameters:
expr1 – The first expression string.
expr2 – The second expression string.
- Returns:
A map from variable name (or sub-expression key) to its integer coefficient in expr1 - expr2. Zero-coefficient terms are omitted.
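The "difference of linear combinations" idea can be sketched independently of the parser. The snippet below assumes each expression has already been reduced to a map from variable name to integer coefficient; the alias Coeffs and the function diff_coeffs are hypothetical and only illustrate the subtraction and zero-filtering step.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Hypothetical representation of a linear combination: variable -> coefficient.
using Coeffs = std::map<std::string, int64_t>;

// Returns lhs - rhs term by term, dropping entries whose coefficient is zero.
inline Coeffs diff_coeffs(const Coeffs &lhs, const Coeffs &rhs) {
  Coeffs out = lhs;
  for (const auto &[name, c] : rhs) out[name] -= c;  // missing keys start at 0
  for (auto it = out.begin(); it != out.end();) {
    if (it->second == 0) it = out.erase(it);
    else ++it;
  }
  return out;
}
```

Feeding in the coefficients of "s52+seq_length" and "s52+s70" leaves s70 with -1 and seq_length with 1, while "e*2" against "e+e" (both {e: 2}) leaves an empty map.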
-
int64_t evaluate_expression(const std::string &expr, const std::unordered_map<std::string, int64_t> &context)
Evaluates expr with the variable assignments in context.
Supported constructs:
Signed 64-bit integer constants.
Variable references resolved via context.
Binary operators +, -, *, // (floor division), % (modulo), ^ (max), & (min).
Unary -.
CeilToInt(n, div): ceiling division, (n % div == 0) ? n/div : n/div + 1.
int64_t v = evaluate_expression("x - y", {{"x", 5}, {"y", 6}}); // v == -1
int64_t c = evaluate_expression("CeilToInt(7, 2)", {}); // c == 4
- Parameters:
expr – The expression string to evaluate.
context – A map from variable name to its integer value.
- Throws:
std::runtime_error – if the expression has a syntax error, references an unknown variable, or contains an unsupported node type.
- Returns:
The integer result of evaluating the expression.
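The page gives two formulas for ceiling division: the branch form used by evaluate_expression and the expanded form (x + n - 1) // n produced by CeilToIntTransformer. The sketch below writes both as plain functions and checks that they agree, assuming n >= 0 and div > 0 (the usual case for shape dimensions). The function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Ceiling division in the branch form given for CeilToInt(n, div).
inline int64_t ceil_div_branch(int64_t n, int64_t div) {
  return (n % div == 0) ? n / div : n / div + 1;
}

// Ceiling division in the expanded form (n + div - 1) // div.
inline int64_t ceil_div_expanded(int64_t n, int64_t div) {
  return (n + div - 1) / div;
}

// Checks that both forms agree for 0 <= n < limit and 1 <= div < limit.
inline bool forms_agree(int64_t limit) {
  for (int64_t n = 0; n < limit; ++n) {
    for (int64_t div = 1; div < limit; ++div) {
      if (ceil_div_branch(n, div) != ceil_div_expanded(n, div)) return false;
    }
  }
  return true;
}
```

Both forms give 4 for n = 7, div = 2. The equivalence does not carry over to negative n with truncating C++ division, which is one reason to keep the non-negativity assumption explicit.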
-
std::unordered_set<std::string> parse_expression_tokens(const std::string &expr)
Returns the set of variable names referenced in expr.
Parses expr and walks the AST to collect every Name node. If the expression has a syntax error, the function returns {expr} (a set containing the original string), matching the Python reference behaviour.
auto tokens = parse_expression_tokens("a + b * c"); // tokens == {"a", "b", "c"}
auto bad = parse_expression_tokens("a +"); // bad == {"a +"} (syntax error → original string returned)
- Parameters:
expr – The expression string to scan.
- Returns:
An unordered set of variable name strings. Contains only expr itself when parsing fails.
-
std::string rename_expression(const std::string &expr, const std::unordered_map<std::string, std::string> &mapping)
Renames variables in expr according to mapping.
Also converts Max(a, b) calls to the a^b xor form before renaming. The result has all spaces removed (matching the Python reference output).
std::string r = rename_expression("s52 + seq_length", {{"s52", "B"}}); // r == "B+seq_length"
std::string m = rename_expression("Max(s10, s3)", {{"s10", "E"}, {"s3", "D"}}); // m == "E^D" (Max is rewritten to ^ before renaming)
- Parameters:
expr – The expression string to rename.
mapping – A map from old variable name to new variable name.
- Throws:
std::runtime_error – if expr cannot be parsed.
- Returns:
The renamed expression string (no spaces).
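The core of renaming is replacing whole identifiers, never substrings of longer names. The sketch below does this at the text level with a simple scanner; this is a swapped-in technique for illustration, since the documented function works on the parsed AST. It does not validate syntax and does not rewrite Max(a, b). The name rename_identifiers is hypothetical.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string>
#include <unordered_map>

// Replaces whole identifiers found in `expr` using `mapping` and drops spaces.
// Text-level sketch only: no parsing, no Max(a, b) rewriting.
inline std::string rename_identifiers(
    const std::string &expr,
    const std::unordered_map<std::string, std::string> &mapping) {
  std::string out;
  std::size_t i = 0;
  while (i < expr.size()) {
    const unsigned char c = static_cast<unsigned char>(expr[i]);
    if (std::isalpha(c) || c == '_') {
      // Scan to the end of the identifier so "s5" never matches inside "s52".
      std::size_t j = i;
      while (j < expr.size() &&
             (std::isalnum(static_cast<unsigned char>(expr[j])) || expr[j] == '_')) {
        ++j;
      }
      const std::string name = expr.substr(i, j - i);
      const auto it = mapping.find(name);
      out += (it != mapping.end()) ? it->second : name;
      i = j;
    } else {
      if (!std::isspace(c)) out += expr[i];  // keep operators and digits
      ++i;
    }
  }
  return out;
}
```

Names missing from the mapping pass through unchanged, which matches the "s52 + seq_length" example where only s52 is renamed.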
-
std::string rename_dynamic_expression(const std::string &expression, const std::unordered_map<std::string, std::string> &replacements)
Renames variables in expression using replacements, then simplifies.
Applies the following pipeline in order:
Parse expression.
Rewrite Max(a, b) → a ^ b.
Apply the rename mapping.
Apply SimpleSimplifyTransformer.
Unparse and strip spaces.
Returns expression unchanged if it has a syntax error.
std::string r = rename_dynamic_expression("s9+seq_length", {{"s9", "cache_length"}, {"seq_length", "seq_length"}});
// r == "cache_length+seq_length"
- Parameters:
expression – The expression string to transform.
replacements – A map from old variable name to new variable name.
- Returns:
The renamed and simplified expression string (no spaces), or expression unchanged on parse failure.
-
std::map<std::string, std::string> rename_dynamic_dimensions(const std::map<std::string, std::unordered_set<std::string>> &constraints, const std::unordered_set<std::string> &original, const std::string &ban_prefix = "DYN")
Renames dynamic shape dimensions from internal names to user-visible ones.
Frameworks such as torch.export.export produce many internal dimension names (e.g. s0, s1, …) for dynamic shapes. This function replaces them with the canonical names supplied by the user via original.
The algorithm iterates over constraints; for each entry it finds the intersection of the equivalent-name set with original, picks the lexicographically smallest match as the canonical name, and propagates it to all aliases, unless the name starts with ban_prefix.
std::map<std::string, std::unordered_set<std::string>> constraints = {
    {"s0", {"batch", "s12"}},
    {"s12", {"batch", "s0"}},
};
std::unordered_set<std::string> original = {"batch"};
auto renamed = rename_dynamic_dimensions(constraints, original);
// renamed["s0"] == "batch"
// renamed["s12"] == "batch"
- Parameters:
constraints – A map from each dimension name to the set of all dimension names that are known to be equal to it (i.e. the equivalence class).
original – The set of user-visible (preferred) dimension names.
ban_prefix – Names starting with this prefix are never selected as the canonical replacement (default: "DYN").
- Returns:
A map {internal_name → canonical_name} covering all names in original (mapped to themselves) plus every name in constraints that was successfully resolved.
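The selection rule described above can be sketched as follows. This is one reading of the prose, not the library's implementation: it assumes the ban applies to candidate canonical names, that an earlier assignment is never overwritten, and that entries with no user-visible match are simply left out. Constraints and rename_dims_sketch are hypothetical names.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <unordered_set>

using Constraints = std::map<std::string, std::unordered_set<std::string>>;

// Maps each resolvable name to the smallest non-banned user-visible name
// found in its equivalence class.
inline std::map<std::string, std::string> rename_dims_sketch(
    const Constraints &constraints,
    const std::unordered_set<std::string> &original,
    const std::string &ban_prefix = "DYN") {
  std::map<std::string, std::string> renamed;
  for (const auto &name : original) renamed[name] = name;  // user names map to themselves

  for (const auto &[key, aliases] : constraints) {
    std::string best;  // lexicographically smallest acceptable candidate
    for (const auto &alias : aliases) {
      const bool banned = !ban_prefix.empty() &&
                          alias.compare(0, ban_prefix.size(), ban_prefix) == 0;
      if (original.count(alias) != 0 && !banned && (best.empty() || alias < best)) {
        best = alias;
      }
    }
    if (best.empty()) continue;   // nothing user-visible in this class
    renamed.emplace(key, best);   // emplace keeps an earlier assignment
    for (const auto &alias : aliases) renamed.emplace(alias, best);
  }
  return renamed;
}
```

On the two-entry constraints example above, this sketch maps s0, s12, and batch all to "batch".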
-
std::string dim_to_string(const DimType &d)
Returns a string representation of d.
Converts an int64_t to its decimal string; returns the std::string variant unchanged.
- Parameters:
d – The dimension value to convert.
- Returns:
Decimal string for integer dimensions; the symbol string otherwise.
-
DimType dim_mul(const DimType &a, const DimType &b)
Multiplies two dimensions.
Returns a * b as an int64_t when both operands are integers. Otherwise builds the expression "(a)*(b)" and simplifies it symbolically.
dim_mul(DimType{int64_t{3}}, DimType{int64_t{4}}) == DimType{int64_t{12}};
// dim_mul("n", 2) returns a string containing "n" and "2"
- Parameters:
a – The first dimension (integer or symbolic string).
b – The second dimension (integer or symbolic string).
- Returns:
The product as an integer when both are concrete, or as a simplified string otherwise.
-
DimType dim_multi_mul(const std::vector<DimType> &args)
Multiplies a sequence of dimensions.
Computes the product of all elements in args. If every element is an int64_t, the result is an exact integer product; otherwise the expression "(a0)*(a1)*..." is built and simplified symbolically.
dim_multi_mul({DimType{int64_t{2}}, DimType{int64_t{3}}, DimType{int64_t{4}}}) == DimType{int64_t{24}};
- Parameters:
args – Vector of dimensions, expected to be non-empty; an empty vector yields int64_t{1}.
- Returns:
The product as an integer when all operands are concrete, or as a simplified string otherwise.
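The integer-or-symbolic split shared by the dim_* helpers can be sketched without the simplifier. In the snippet below, an all-integer input gives an exact product, and any symbolic operand makes the function return the raw "(a0)*(a1)*..." text; the real function would pass that text on for simplification. Dim and multi_mul_sketch are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in with the same shape as DimType.
using Dim = std::variant<int64_t, std::string>;

// Multiplies all dimensions. Integer-only inputs give an exact integer;
// otherwise the unsimplified text "(a0)*(a1)*..." is returned.
inline Dim multi_mul_sketch(const std::vector<Dim> &args) {
  int64_t product = 1;  // also the result for an empty vector
  bool all_int = true;
  std::string text;
  for (const Dim &d : args) {
    if (!text.empty()) text += "*";
    if (const auto *i = std::get_if<int64_t>(&d)) {
      product *= *i;
      text += "(" + std::to_string(*i) + ")";
    } else {
      all_int = false;
      text += "(" + std::get<std::string>(d) + ")";
    }
  }
  if (all_int) return Dim{product};
  return Dim{text};
}
```

Wrapping every operand in parentheses keeps the product correct when an operand is itself a sum such as "n+1".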
-
DimType dim_add(const DimType &a, const DimType &b)
Adds two dimensions.
Returns a + b as an int64_t when both are integers; otherwise builds "(a)+(b)" and simplifies.
dim_add(DimType{int64_t{3}}, DimType{int64_t{4}}) == DimType{int64_t{7}};
- Parameters:
a – The first dimension.
b – The second dimension.
- Returns:
The sum.
-
DimType dim_sub(const DimType &a, const DimType &b)
Subtracts b from a.
Returns a - b as an int64_t when both are integers; otherwise builds "(a)-(b)" and simplifies.
dim_sub(DimType{int64_t{10}}, DimType{int64_t{3}}) == DimType{int64_t{7}};
- Parameters:
a – The minuend.
b – The subtrahend.
- Returns:
The difference.
-
DimType dim_div(const DimType &a, const DimType &b)
Floor-divides a by b.
Assumes both values are non-negative (as is typical for ONNX shape dimensions). Returns a // b as an int64_t when both are integers; otherwise builds "(a)//(b)" and simplifies.
dim_div(DimType{int64_t{7}}, DimType{int64_t{2}}) == DimType{int64_t{3}};
dim_div(DimType{std::string{"2*n"}}, DimType{int64_t{2}}) == DimType{std::string{"n"}};
- Parameters:
a – The dividend.
b – The divisor.
- Returns:
The floor-division result.
-
DimType dim_mod(const DimType &a, const DimType &b)
Computes a modulo b.
Returns a % b as an int64_t when both are integers; otherwise builds "(a)%(b)" and simplifies.
dim_mod(DimType{int64_t{10}}, DimType{int64_t{3}}) == DimType{int64_t{1}};
- Parameters:
a – The dividend.
b – The divisor.
- Returns:
The remainder.
-
DimType dim_max(const DimType &a, const DimType &b)
Returns the maximum of a and b.
Returns max(a, b) as an int64_t when both are integers; otherwise builds "(a)^(b)" (the xor encoding of max) and simplifies.
dim_max(DimType{int64_t{7}}, DimType{int64_t{3}}) == DimType{int64_t{7}};
// dim_max("n", "n") simplifies to DimType{std::string{"n"}} (x^x → x)
- Parameters:
a – The first dimension.
b – The second dimension.
- Returns:
The maximum of the two dimensions.
-
DimType dim_min(const DimType &a, const DimType &b)
Returns the minimum of a and b.
Returns min(a, b) as an int64_t when both are integers; otherwise builds "(a)&(b)" (the ampersand encoding of min) and simplifies.
dim_min(DimType{int64_t{2}}, DimType{int64_t{9}}) == DimType{int64_t{2}};
- Parameters:
a – The first dimension.
b – The second dimension.
- Returns:
The minimum of the two dimensions.
-
struct BinOp : public onnx_light::onnx_optim::expressions::Node
- #include <expressions.h>
Interior node representing a binary arithmetic operation.
auto b = std::make_unique<BinOp>(
    std::make_unique<Name>("a"), BinOpKind::Add, std::make_unique<Constant>(1));
// unparse(*b) == "a+1"
-
struct Call : public onnx_light::onnx_optim::expressions::Node
- #include <expressions.h>
Interior node representing a function call (e.g. CeilToInt, Max).
The only function call understood by evaluate_expression is CeilToInt(n, div), which computes ceiling division. Max(a, b) is syntactic sugar that MaxToXorTransformer rewrites to a ^ b before evaluation.
// After MaxToXorTransformer, Max(a, b) becomes BinOp(a, BitXor, b).
// CeilToInt(n, 2) is evaluated as (n % 2 == 0) ? n/2 : n/2+1.
-
struct Constant : public onnx_light::onnx_optim::expressions::Node
- #include <expressions.h>
Leaf node representing a signed 64-bit integer constant.
auto c = std::make_unique<Constant>(42);
// unparse(*c) == "42"
Public Functions
Public Members
-
int64_t value
The integer value of this constant.
-
struct Name : public onnx_light::onnx_optim::expressions::Node
- #include <expressions.h>
Leaf node representing a symbolic variable reference.
auto n = std::make_unique<Name>("batch");
// unparse(*n) == "batch"
-
struct Node
- #include <expressions.h>
Abstract base class for all expression AST nodes.
Every concrete node type derives from Node and overrides clone() to produce a deep copy. Nodes are always heap-allocated and owned by NodePtr.
Subclassed by onnx_light::onnx_optim::expressions::BinOp, onnx_light::onnx_optim::expressions::Call, onnx_light::onnx_optim::expressions::Constant, onnx_light::onnx_optim::expressions::Name, onnx_light::onnx_optim::expressions::UnaryOp
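The clone-based ownership model can be illustrated with a stripped-down hierarchy. The types below live in a hypothetical namespace sketch and only mirror the documented idea (an abstract base, unique_ptr ownership, a deep-copying clone()); member names, constructors, and the text() helper are assumptions and may differ from the real header.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

namespace sketch {

struct Node {
  virtual ~Node() = default;
  virtual std::unique_ptr<Node> clone() const = 0;  // deep copy
  virtual std::string text() const = 0;             // debugging aid, not unparse()
};

struct Name : Node {
  std::string id;
  explicit Name(std::string s) : id(std::move(s)) {}
  std::unique_ptr<Node> clone() const override { return std::make_unique<Name>(id); }
  std::string text() const override { return id; }
};

struct Constant : Node {
  int64_t value;
  explicit Constant(int64_t v) : value(v) {}
  std::unique_ptr<Node> clone() const override { return std::make_unique<Constant>(value); }
  std::string text() const override { return std::to_string(value); }
};

// Addition only, to keep the sketch short; the node owns both children.
struct Add : Node {
  std::unique_ptr<Node> left, right;
  Add(std::unique_ptr<Node> l, std::unique_ptr<Node> r)
      : left(std::move(l)), right(std::move(r)) {}
  std::unique_ptr<Node> clone() const override {
    return std::make_unique<Add>(left->clone(), right->clone());  // recurse into children
  }
  std::string text() const override { return left->text() + "+" + right->text(); }
};

}  // namespace sketch
```

Because clone() recurses into children, mutating the original tree after cloning leaves the copy untouched, which is what lets transformers rewrite a tree without aliasing problems.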
-
struct UnaryOp : public onnx_light::onnx_optim::expressions::Node
- #include <expressions.h>
Interior node representing a unary arithmetic operation.
auto u = std::make_unique<UnaryOp>(UnaryOpKind::USub, std::make_unique<Name>("x"));
// unparse(*u) == "-x"
Public Functions
-
inline UnaryOp(UnaryOpKind o, NodePtr n)
Constructs a UnaryOp node.
- Parameters:
o – The operator kind.
n – The operand (ownership transferred).