expressions.h

Symbolic dimension-expression utilities for ONNX shape inference.

Provides a lightweight AST-based library for parsing, simplifying, evaluating, and renaming symbolic shape expressions such as those produced during ONNX shape inference (e.g. "2*batch//2" → "batch").

See Expressions for the Python interface.

namespace onnx_light
namespace onnx_optim
namespace expressions

Typedefs

using NodePtr = std::unique_ptr<Node>

Owning pointer to an AST node.

All AST construction and transformation functions return NodePtr values; the tree owns its children and is freed when the root NodePtr goes out of scope.

using SimplifyResult = std::variant<int64_t, std::string>

Return type of simplify_expression.

Holds either an int64_t when the expression reduces to a pure numeric constant, or a std::string when symbolic variables remain after simplification.

Use std::holds_alternative<int64_t>(r) to check which case applies, or call simplify_result_to_string() for a uniform string representation.

SimplifyResult r = simplify_expression("2*batch//batch");
assert(std::holds_alternative<int64_t>(r));
assert(std::get<int64_t>(r) == 2);

SimplifyResult s = simplify_expression("a + b");
assert(std::holds_alternative<std::string>(s));
assert(std::get<std::string>(s) == "a+b");

using DimType = std::variant<int64_t, std::string>

A dimension value: either a concrete integer or a symbolic string.

int64_t is used when the dimension is statically known; std::string when it is symbolic (e.g. "batch" or "seq_length+1").

DimType d1 = int64_t{64};       // concrete dimension
DimType d2 = std::string{"N"};  // symbolic dimension

Enums

enum class BinOpKind

Binary operator kind used in the expression AST.

The ^ and & operators are borrowed from Python’s bitwise-xor / bitwise-and syntax and are re-interpreted as max and min in this expression system, matching the convention from yobx/xexpressions.

Values:

enumerator Add

Addition: a + b.

enumerator Sub

Subtraction: a - b.

enumerator Mult

Multiplication: a * b.

enumerator FloorDiv

Floor (integer) division: a // b.

enumerator Mod

Modulo: a % b.

enumerator BitXor

Encodes max(a, b) using the ^ syntax.

enumerator BitAnd

Encodes min(a, b) using the & syntax.
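As an illustration of the operator semantics listed above, here is a minimal integer evaluator for a single binary node. It assumes Python-style results for // and % (rounding toward negative infinity, remainder taking the divisor's sign), which the "floor division" wording and the Python reference suggest; C++'s built-in / and % truncate toward zero, so they are adjusted. The local enum copy and the helpers floor_div, floor_mod and apply_binop_sketch are hypothetical, and the header does not state how negative operands are handled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Local copy of the documented enumerators (illustration only).
enum class BinOpKind { Add, Sub, Mult, FloorDiv, Mod, BitXor, BitAnd };

// Python-style floor division: rounds toward negative infinity.
int64_t floor_div(int64_t a, int64_t b) {
  assert(b != 0);
  int64_t q = a / b;  // C++ truncates toward zero
  if (a % b != 0 && ((a < 0) != (b < 0))) --q;
  return q;
}

// Python-style modulo: the result takes the sign of the divisor.
int64_t floor_mod(int64_t a, int64_t b) { return a - floor_div(a, b) * b; }

// Hypothetical evaluation of one BinOp whose operands are already integers.
int64_t apply_binop_sketch(BinOpKind op, int64_t a, int64_t b) {
  switch (op) {
    case BinOpKind::Add: return a + b;
    case BinOpKind::Sub: return a - b;
    case BinOpKind::Mult: return a * b;
    case BinOpKind::FloorDiv: return floor_div(a, b);
    case BinOpKind::Mod: return floor_mod(a, b);
    case BinOpKind::BitXor: return std::max(a, b);  // ^ encodes max
    case BinOpKind::BitAnd: return std::min(a, b);  // & encodes min
  }
  return 0;  // not reached for valid enumerators
}
```

Since ONNX dimensions are normally non-negative, the truncation adjustment only matters for intermediate negative values, for example those introduced by unary minus.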

enum class UnaryOpKind

Unary operator kind used in the expression AST.

Values:

enumerator USub

Unary minus: -a.

enumerator UAdd

Unary plus: +a (identity).

Functions

NodePtr parse(const std::string &expr)

Parses expr into an AST.

The grammar follows Python operator precedence for the supported subset of operators: +, -, *, //, %, ^ (max), & (min), unary +/-, parentheses, and function calls with comma-separated arguments.

Operator precedence (low to high):

  • ^ (BitXor / max)

  • & (BitAnd / min)

  • +, -

  • *, //, %

  • unary -, +

  • atoms (constants, names, parenthesised sub-expressions, calls)
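The precedence table maps naturally onto a recursive-descent parser with one function per level, lowest precedence first. The sketch below is a hypothetical, integer-only evaluator (no names or calls) that computes values directly instead of building an AST; it only shows how the levels nest, and it assumes Python-style floor semantics for // and %.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical integer-only evaluator: one member function per precedence level.
class PrecedenceSketch {
 public:
  explicit PrecedenceSketch(std::string text) : s_(std::move(text)) {}

  int64_t run() {
    int64_t v = max_level();
    skip_ws();
    if (pos_ != s_.size()) throw std::runtime_error("unexpected trailing input");
    return v;
  }

 private:
  std::string s_;
  std::size_t pos_ = 0;

  void skip_ws() {
    while (pos_ < s_.size() && std::isspace(static_cast<unsigned char>(s_[pos_]))) ++pos_;
  }
  // Consumes `tok` if it is the next token in the input.
  bool eat(const std::string& tok) {
    skip_ws();
    if (s_.compare(pos_, tok.size(), tok) != 0) return false;
    pos_ += tok.size();
    return true;
  }
  static int64_t floor_div(int64_t a, int64_t b) {
    assert(b != 0);
    int64_t q = a / b;
    if (a % b != 0 && ((a < 0) != (b < 0))) --q;
    return q;
  }

  // ^ (max): lowest precedence.
  int64_t max_level() {
    int64_t v = min_level();
    while (eat("^")) v = std::max(v, min_level());
    return v;
  }
  // & (min).
  int64_t min_level() {
    int64_t v = add_level();
    while (eat("&")) v = std::min(v, add_level());
    return v;
  }
  // +, -
  int64_t add_level() {
    int64_t v = mul_level();
    for (;;) {
      if (eat("+")) v += mul_level();
      else if (eat("-")) v -= mul_level();
      else return v;
    }
  }
  // *, //, %
  int64_t mul_level() {
    int64_t v = unary_level();
    for (;;) {
      if (eat("*")) v *= unary_level();
      else if (eat("//")) v = floor_div(v, unary_level());
      else if (eat("%")) { int64_t d = unary_level(); v -= floor_div(v, d) * d; }
      else return v;
    }
  }
  // Unary -, + bind tighter than any binary operator.
  int64_t unary_level() {
    if (eat("-")) return -unary_level();
    if (eat("+")) return unary_level();
    return atom();
  }
  // Integer literal or parenthesised sub-expression.
  int64_t atom() {
    if (eat("(")) {
      int64_t v = max_level();
      if (!eat(")")) throw std::runtime_error("expected ')'");
      return v;
    }
    skip_ws();
    std::size_t start = pos_;
    while (pos_ < s_.size() && std::isdigit(static_cast<unsigned char>(s_[pos_]))) ++pos_;
    if (start == pos_) throw std::runtime_error("expected an integer");
    return std::stoll(s_.substr(start, pos_ - start));
  }
};
```

For example, PrecedenceSketch("2 ^ 3 & 1").run() yields 2, because & binds tighter than ^ and the expression reads as max(2, min(3, 1)).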

auto tree = parse("2*batch//batch");
// tree is BinOp(BinOp(Constant(2), Mult, Name("batch")), FloorDiv, Name("batch"))

Parameters:

expr – The expression string to parse.

Throws:

std::runtime_error – if the input contains a lexical or syntax error.

Returns:

An owning pointer to the root of the parsed AST.

std::string unparse(const Node &node)

Converts node back to a canonical string expression.

Follows Python’s ast.unparse parenthesisation rules: operator precedence determines where parentheses are inserted so that the output round-trips through parse() to an equivalent AST.

auto tree = parse("(a + b) * c");
std::string s = unparse(*tree);
// s == "(a+b)*c"

Parameters:

node – The root AST node to convert.

Returns:

A string representation of the expression without extra spaces.
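The parenthesisation rule can be sketched with one precedence number per operator: a left child is wrapped when it binds more loosely than its parent, and a right child is wrapped when it binds equally or more loosely, since these binary operators are left-associative. The tiny shared-pointer AST, the precedence values and unparse_sketch are illustrative assumptions, not the header's types, and unary operators are left out.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Illustrative mini AST: a leaf has an empty `op` and carries `text`.
struct Expr {
  std::string op;
  std::string text;
  std::shared_ptr<Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr make_leaf(std::string s) {
  return std::make_shared<Expr>(Expr{"", std::move(s), nullptr, nullptr});
}
ExprPtr make_bin(std::string op, ExprPtr l, ExprPtr r) {
  return std::make_shared<Expr>(Expr{std::move(op), "", std::move(l), std::move(r)});
}

// Binding strength following the documented table; leaves bind tightest.
int precedence(const Expr& e) {
  if (e.op.empty()) return 100;
  if (e.op == "^") return 1;
  if (e.op == "&") return 2;
  if (e.op == "+" || e.op == "-") return 3;
  return 4;  // "*", "//", "%"
}

// Hypothetical unparser for left-associative binary operators, no spaces.
std::string unparse_sketch(const Expr& e) {
  if (e.op.empty()) return e.text;
  const int p = precedence(e);
  std::string l = unparse_sketch(*e.lhs);
  std::string r = unparse_sketch(*e.rhs);
  if (precedence(*e.lhs) < p) l = "(" + l + ")";   // looser left child
  if (precedence(*e.rhs) <= p) r = "(" + r + ")";  // equal or looser right child
  return l + e.op + r;
}
```

Like Python's ast.unparse, this rule keeps the parentheses in a+(b+c) even though addition is associative, because it looks only at precedence and side, not at the operator's algebraic properties.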

std::string simplify_result_to_string(const SimplifyResult &r)

Returns a string representation of a SimplifyResult.

Converts an int64_t result to its decimal string; returns the std::string variant unchanged.

Parameters:

r – The result to convert.

Returns:

Decimal string for integer results; the simplified expression string otherwise.
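A minimal sketch of the documented conversion uses std::visit over the two variant alternatives. The name result_to_string_sketch is hypothetical, and SimplifyResult is redeclared locally so the snippet compiles on its own.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>
#include <variant>

// Local stand-in for the documented alias.
using SimplifyResult = std::variant<int64_t, std::string>;

// Hypothetical helper: integers become decimal strings, strings pass through.
std::string result_to_string_sketch(const SimplifyResult& r) {
  return std::visit(
      [](const auto& v) -> std::string {
        if constexpr (std::is_same_v<std::decay_t<decltype(v)>, int64_t>) {
          return std::to_string(v);
        } else {
          return v;
        }
      },
      r);
}
```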

SimplifyResult simplify_expression(const std::string &expr)

Simplifies a symbolic or numeric expression string.

Applies a pipeline of AST transformations twice:

  1. CeilToIntTransformer: expands CeilToInt(x, n) to (x + n - 1) // n.

  2. SimpleSimplifyTransformer: folds x^x → x, x + 0 → x, x * 1 → x.

  3. MulDivCancellerTransformer: cancels common symbolic factors, e.g. 2*x//x → 2.

  4. ExactMulDivConstantFolderTransformer: folds 1024*a//2 → 512*a.

  5. MaxToXorTransformer: rewrites Max(a,b) and max(a,b) to a^b.

  6. ReorderCommutativeOpsTransformer: sorts +/* operands alphabetically.

  7. MaxIntTransformer: folds int_const ^ int_const into the larger of the two constants.

A final linear-combination visitor then collects the result as a normalised sum of symbolic terms plus an integer constant.
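The final linear-combination step can be pictured as arithmetic on a map from symbolic term to integer coefficient, plus a separate constant. The LinearForm type below is a hypothetical illustration of that representation, not the header's visitor: it covers only addition, subtraction and scaling by an integer, and it drops zero coefficients so that cancelled terms disappear, as in "a + b - a" → "b".

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Hypothetical normal form: sum of coefficient * term, plus an integer constant.
struct LinearForm {
  std::map<std::string, int64_t> coeffs;  // ordered keys give a stable output order
  int64_t constant = 0;

  static LinearForm term(const std::string& name) {
    LinearForm f;
    f.coeffs[name] = 1;
    return f;
  }
  static LinearForm number(int64_t v) {
    LinearForm f;
    f.constant = v;
    return f;
  }
  // this += k * other, erasing coefficients that cancel to zero.
  LinearForm& add_scaled(const LinearForm& other, int64_t k) {
    for (const auto& [name, c] : other.coeffs) {
      int64_t updated = coeffs[name] + k * c;
      if (updated == 0) coeffs.erase(name);
      else coeffs[name] = updated;
    }
    constant += k * other.constant;
    return *this;
  }
  LinearForm operator+(const LinearForm& o) const { LinearForm r = *this; r.add_scaled(o, 1); return r; }
  LinearForm operator-(const LinearForm& o) const { LinearForm r = *this; r.add_scaled(o, -1); return r; }
  LinearForm operator*(int64_t k) const { LinearForm r; r.add_scaled(*this, k); return r; }
};
```

Products of two symbolic terms and floor divisions are not linear, so a real visitor would have to treat such sub-expressions as opaque keys; this sketch does not attempt that.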

// Fully numeric result:
auto r1 = simplify_expression("2*batch//batch");
assert(std::get<int64_t>(r1) == 2);

// Symbolic result:
auto r2 = simplify_expression("a + b - a");
assert(std::get<std::string>(r2) == "b");

// CeilToInt expansion:
auto r3 = simplify_expression("CeilToInt(b+c, 2)");
assert(std::get<std::string>(r3) == "(1+b+c)//2");

Parameters:

expr – The expression string to simplify.

Returns:

An int64_t when the result is fully numeric, or a simplified std::string otherwise. Returns expr unchanged when it contains syntax that the parser does not recognise (e.g. "::" in ONNX node names).

SimplifyResult simplify_expression(int64_t value)

Returns the integer as-is (convenience overload for uniform call sites).

Parameters:

value – An integer that is already fully simplified.

Returns:

A SimplifyResult holding value.

std::map<std::string, int64_t> simplify_two_expressions(const std::string &expr1, const std::string &expr2)

Returns the non-zero coefficient map of the difference expr1 - (expr2).

Builds the combined expression expr1 - (expr2), runs the linear-combination visitor, and returns only those variable coefficients that are non-zero. An empty map indicates that the two expressions are equal under linear arithmetic.

auto diff = simplify_two_expressions("s52+seq_length", "s52+s70");
// diff == {{"s70", -1}, {"seq_length", 1}}

auto same = simplify_two_expressions("e*2", "e+e");
// same is empty — the two expressions are equal

Parameters:
  • expr1 – The first expression string.

  • expr2 – The second expression string.

Returns:

A map from variable name (or sub-expression key) to its integer coefficient in expr1 - expr2. Zero-coefficient terms are omitted.
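Once each side has been reduced to a coefficient map, the difference described above becomes a per-key subtraction that omits zeros. The helper coefficient_difference is hypothetical and starts from already-computed maps rather than parsing strings.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

using Coeffs = std::map<std::string, int64_t>;

// Hypothetical: coefficients of lhs - rhs, with zero entries omitted.
Coeffs coefficient_difference(const Coeffs& lhs, const Coeffs& rhs) {
  Coeffs out = lhs;
  for (const auto& [name, c] : rhs) out[name] -= c;
  for (auto it = out.begin(); it != out.end();) {
    if (it->second == 0) it = out.erase(it);  // equal terms cancel
    else ++it;
  }
  return out;
}
```

With {s52: 1, seq_length: 1} and {s52: 1, s70: 1} as inputs this gives {s70: -1, seq_length: 1}, matching the documented example.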

int64_t evaluate_expression(const std::string &expr, const std::unordered_map<std::string, int64_t> &context)

Evaluates expr with the variable assignments in context.

Supported constructs:

  • Signed 64-bit integer constants.

  • Variable references resolved via context.

  • Binary operators +, -, *, // (floor division), % (modulo), ^ (max), & (min).

  • Unary -.

  • CeilToInt(n, div) — ceiling division: (n % div == 0) ? n/div : n/div + 1.

int64_t v = evaluate_expression("x - y", {{"x", 5}, {"y", 6}});
// v == -1

int64_t c = evaluate_expression("CeilToInt(7, 2)", {});
// c == 4

Parameters:
  • expr – The expression string to evaluate.

  • context – A map from variable name to its integer value.

Throws:

std::runtime_error – if the expression has a syntax error, references an unknown variable, or contains an unsupported node type.

Returns:

The integer result of evaluating the expression.
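For non-negative n and positive div, the CeilToInt rule listed under the supported constructs agrees with the (n + div - 1) // div expansion that CeilToIntTransformer applies during simplification. The sketch below checks that equivalence exhaustively over a small range; both helper names are hypothetical, and the check deliberately excludes negative n.

```cpp
#include <cassert>
#include <cstdint>

// Rule as documented for evaluate_expression.
int64_t ceil_to_int_eval(int64_t n, int64_t div) {
  assert(div > 0);
  return (n % div == 0) ? n / div : n / div + 1;
}

// Expansion used by CeilToIntTransformer: (n + div - 1) // div.
// With n >= 0 and div > 0, C++ truncating division equals floor division.
int64_t ceil_to_int_expanded(int64_t n, int64_t div) {
  assert(n >= 0 && div > 0);
  return (n + div - 1) / div;
}

// True when both forms agree for 0 <= n < n_max and 1 <= div < div_max.
bool ceil_forms_agree(int64_t n_max, int64_t div_max) {
  for (int64_t div = 1; div < div_max; ++div)
    for (int64_t n = 0; n < n_max; ++n)
      if (ceil_to_int_eval(n, div) != ceil_to_int_expanded(n, div)) return false;
  return true;
}
```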

std::unordered_set<std::string> parse_expression_tokens(const std::string &expr)

Returns the set of variable names referenced in expr.

Parses expr and walks the AST to collect every Name node. If the expression has a syntax error the function returns {expr} (a set containing the original string), matching the Python reference behaviour.

auto tokens = parse_expression_tokens("a + b * c");
// tokens == {"a", "b", "c"}

auto bad = parse_expression_tokens("a +");
// bad == {"a +"} (syntax error → original string returned)

Parameters:

expr – The expression string to scan.

Returns:

An unordered set of variable name strings. Contains only expr itself when parsing fails.
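The AST walk described above can be sketched with minimal stand-ins that mirror the documented node members (Name::id, BinOp::left/right, UnaryOp::operand, Call::func/args). These structs and collect_names are illustrative re-declarations with the operator fields omitted, not the header's definitions. In this sketch a call's function name is a plain string (Call::func), so CeilToInt is not reported as a variable.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Minimal stand-ins mirroring the documented members; operator kinds omitted.
struct Node { virtual ~Node() = default; };
using NodePtr = std::unique_ptr<Node>;

struct Name : Node {
  std::string id;
  explicit Name(std::string s) : id(std::move(s)) {}
};
struct Constant : Node {
  int64_t value;
  explicit Constant(int64_t v) : value(v) {}
};
struct BinOp : Node {
  NodePtr left, right;
  BinOp(NodePtr l, NodePtr r) : left(std::move(l)), right(std::move(r)) {}
};
struct UnaryOp : Node {
  NodePtr operand;
  explicit UnaryOp(NodePtr n) : operand(std::move(n)) {}
};
struct Call : Node {
  std::string func;  // a plain string, so it is never reported as a variable
  std::vector<NodePtr> args;
  Call(std::string f, std::vector<NodePtr> a) : func(std::move(f)), args(std::move(a)) {}
};

// Hypothetical walk: gathers every Name::id reachable from `n`.
void collect_names(const Node& n, std::unordered_set<std::string>& out) {
  if (const auto* name = dynamic_cast<const Name*>(&n)) {
    out.insert(name->id);
  } else if (const auto* bin = dynamic_cast<const BinOp*>(&n)) {
    collect_names(*bin->left, out);
    collect_names(*bin->right, out);
  } else if (const auto* un = dynamic_cast<const UnaryOp*>(&n)) {
    collect_names(*un->operand, out);
  } else if (const auto* call = dynamic_cast<const Call*>(&n)) {
    for (const auto& arg : call->args) collect_names(*arg, out);
  }
  // Constant nodes contribute nothing.
}
```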

std::string rename_expression(const std::string &expr, const std::unordered_map<std::string, std::string> &mapping)

Renames variables in expr according to mapping.

Also converts Max(a, b) calls to the a^b xor form before renaming. The result has all spaces removed (matching the Python reference output).

std::string r = rename_expression("s52 + seq_length", {{"s52", "B"}});
// r == "B+seq_length"

std::string m = rename_expression("Max(s10, s3)", {{"s10", "E"}, {"s3", "D"}});
// m == "E^D"  (Max is rewritten to ^ before renaming)

Parameters:
  • expr – The expression string to rename.

  • mapping – A map from old variable name to new variable name.

Throws:

std::runtime_error – if expr cannot be parsed.

Returns:

The renamed expression string (no spaces).
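The renaming step can also be illustrated without an AST by rewriting whole identifier tokens, which avoids the classic substring bug where renaming s1 would also corrupt s12. This is a swapped-in lexical technique, not the header's AST-based implementation: rename_tokens_sketch is hypothetical, drops spaces like the documented function, and does not perform the Max → ^ rewrite.

```cpp
#include <cassert>
#include <cctype>
#include <string>
#include <unordered_map>

// Hypothetical lexical rename: identifiers match [A-Za-z_][A-Za-z0-9_]* and are
// replaced only as whole tokens; all whitespace is dropped from the output.
std::string rename_tokens_sketch(const std::string& expr,
                                 const std::unordered_map<std::string, std::string>& mapping) {
  auto starts_ident = [](char c) { return std::isalpha(static_cast<unsigned char>(c)) || c == '_'; };
  auto continues_ident = [](char c) { return std::isalnum(static_cast<unsigned char>(c)) || c == '_'; };
  std::string out;
  std::size_t i = 0;
  while (i < expr.size()) {
    if (starts_ident(expr[i])) {
      std::size_t j = i;
      while (j < expr.size() && continues_ident(expr[j])) ++j;
      const std::string token = expr.substr(i, j - i);
      const auto it = mapping.find(token);
      out += (it != mapping.end()) ? it->second : token;
      i = j;
    } else {
      if (!std::isspace(static_cast<unsigned char>(expr[i]))) out += expr[i];
      ++i;
    }
  }
  return out;
}
```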

std::string rename_dynamic_expression(const std::string &expression, const std::unordered_map<std::string, std::string> &replacements)

Renames variables in expression using replacements, then simplifies.

Applies the following pipeline in order:

  1. Parse expression.

  2. Rewrite Max(a, b) → a ^ b.

  3. Apply the rename mapping.

  4. Apply SimpleSimplifyTransformer.

  5. Unparse and strip spaces.

Returns expression unchanged if it has a syntax error.

std::string r = rename_dynamic_expression("s9+seq_length",
    {{"s9", "cache_length"}, {"seq_length", "seq_length"}});
// r == "cache_length+seq_length"

Parameters:
  • expression – The expression string to transform.

  • replacements – A map from old variable name to new variable name.

Returns:

The renamed and simplified expression string (no spaces), or expression unchanged on parse failure.

std::map<std::string, std::string> rename_dynamic_dimensions(const std::map<std::string, std::unordered_set<std::string>> &constraints, const std::unordered_set<std::string> &original, const std::string &ban_prefix = "DYN")

Renames dynamic shape dimensions from internal names to user-visible ones.

Frameworks such as torch.export.export produce many internal dimension names (e.g. s0, s1, …) for dynamic shapes. This function replaces them with the canonical names supplied by the user via original.

The algorithm iterates over constraints; for each entry it finds the intersection of the equivalent-name set with original, picks the lexicographically smallest match as the canonical name, and propagates it to all aliases — unless the name starts with ban_prefix.

std::map<std::string, std::unordered_set<std::string>> constraints = {
    {"s0", {"batch", "s12"}},
    {"s12", {"batch", "s0"}},
};
std::unordered_set<std::string> original = {"batch"};
auto renamed = rename_dynamic_dimensions(constraints, original);
// renamed["s0"] == "batch"
// renamed["s12"] == "batch"

Parameters:
  • constraints – A map from each dimension name to the set of all dimension names that are known to be equal to it (i.e. the equivalence class).

  • original – The set of user-visible (preferred) dimension names.

  • ban_prefix – Names starting with this prefix are never selected as the canonical replacement (default: "DYN").

Returns:

A map {internal_name → canonical_name} covering all names in original (mapped to themselves) plus every name in constraints that was successfully resolved.
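One reading of the algorithm described above can be sketched under explicit assumptions: the ban applies to candidate canonical names, a name that already has a mapping keeps it, and a std::set supplies the lexicographically smallest candidate. rename_dims_sketch is hypothetical and may differ from the real function on edge cases.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <unordered_set>

// Hypothetical reading of rename_dynamic_dimensions.
std::map<std::string, std::string> rename_dims_sketch(
    const std::map<std::string, std::unordered_set<std::string>>& constraints,
    const std::unordered_set<std::string>& original,
    const std::string& ban_prefix = "DYN") {
  std::map<std::string, std::string> out;
  for (const auto& name : original) out[name] = name;  // preferred names map to themselves
  for (const auto& [name, aliases] : constraints) {
    // Candidates: user-visible aliases not starting with ban_prefix, kept sorted.
    std::set<std::string> candidates;
    for (const auto& alias : aliases)
      if (original.count(alias) && alias.rfind(ban_prefix, 0) != 0) candidates.insert(alias);
    if (candidates.empty()) continue;  // unresolved: leave the name out
    const std::string& canonical = *candidates.begin();
    out.emplace(name, canonical);  // emplace keeps any existing mapping
    for (const auto& alias : aliases) out.emplace(alias, canonical);
  }
  return out;
}
```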

std::string dim_to_string(const DimType &d)

Returns a string representation of d.

Converts an int64_t to its decimal string; returns the std::string variant unchanged.

Parameters:

d – The dimension value to convert.

Returns:

Decimal string for integer dimensions; the symbol string otherwise.

DimType dim_mul(const DimType &a, const DimType &b)

Multiplies two dimensions.

Returns a * b as an int64_t when both operands are integers. Otherwise builds the expression "(a)*(b)" and simplifies it symbolically.

dim_mul(DimType{int64_t{3}}, DimType{int64_t{4}}) == DimType{int64_t{12}};
// dim_mul("n", 2) returns a string containing "n" and "2"

Parameters:
  • a – The first dimension (integer or symbolic string).

  • b – The second dimension (integer or symbolic string).

Returns:

The product as an integer when both are concrete, or as a simplified string otherwise.
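The integer fast path with a symbolic fallback can be sketched with std::get_if. This version stops after building the "(a)*(b)" string and does not call a simplifier, so its symbolic output is unsimplified; dim_mul_sketch and dim_text are hypothetical, and DimType is redeclared locally.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>

using DimType = std::variant<int64_t, std::string>;

// Decimal text for integers, the symbol otherwise (mirrors dim_to_string).
std::string dim_text(const DimType& d) {
  if (const auto* i = std::get_if<int64_t>(&d)) return std::to_string(*i);
  return std::get<std::string>(d);
}

DimType dim_mul_sketch(const DimType& a, const DimType& b) {
  const auto* ia = std::get_if<int64_t>(&a);
  const auto* ib = std::get_if<int64_t>(&b);
  if (ia && ib) return *ia * *ib;  // both concrete: exact product
  // Otherwise build the documented "(a)*(b)" form; the real function
  // simplifies this string further.
  return "(" + dim_text(a) + ")*(" + dim_text(b) + ")";
}
```

The same pattern, with a different operator string, fits dim_add, dim_sub, dim_div, dim_mod, dim_max and dim_min.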

DimType dim_multi_mul(const std::vector<DimType> &args)

Multiplies a sequence of dimensions.

Computes the product of all elements in args. If every element is an int64_t the result is an exact integer product; otherwise the expression "(a0)*(a1)*..." is built and simplified symbolically.

dim_multi_mul({DimType{int64_t{2}}, DimType{int64_t{3}}, DimType{int64_t{4}}})
    == DimType{int64_t{24}};

Parameters:

args – Vector of dimensions; an empty vector yields int64_t{1}.

Returns:

The product as an integer when all operands are concrete, or as a simplified string otherwise.
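A single pass over args captures the documented behaviour: an all-integer input gives an exact product (an empty vector gives 1), otherwise the factors are joined as "(a0)*(a1)*...". dim_multi_mul_sketch is hypothetical and, having no simplifier, returns the joined string unsimplified.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

using DimType = std::variant<int64_t, std::string>;

DimType dim_multi_mul_sketch(const std::vector<DimType>& args) {
  int64_t product = 1;
  bool all_int = true;
  std::string joined;
  for (const auto& d : args) {
    if (!joined.empty()) joined += "*";
    if (const auto* i = std::get_if<int64_t>(&d)) {
      product *= *i;
      joined += "(" + std::to_string(*i) + ")";
    } else {
      all_int = false;
      joined += "(" + std::get<std::string>(d) + ")";
    }
  }
  if (all_int) return product;  // also covers the empty vector: 1
  return joined;
}
```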

DimType dim_add(const DimType &a, const DimType &b)

Adds two dimensions.

Returns a + b as an int64_t when both are integers; otherwise builds "(a)+(b)" and simplifies.

dim_add(DimType{int64_t{3}}, DimType{int64_t{4}}) == DimType{int64_t{7}};

Parameters:
  • a – The first dimension.

  • b – The second dimension.

Returns:

The sum.

DimType dim_sub(const DimType &a, const DimType &b)

Subtracts b from a.

Returns a - b as an int64_t when both are integers; otherwise builds "(a)-(b)" and simplifies.

dim_sub(DimType{int64_t{10}}, DimType{int64_t{3}}) == DimType{int64_t{7}};

Parameters:
  • a – The minuend.

  • b – The subtrahend.

Returns:

The difference.

DimType dim_div(const DimType &a, const DimType &b)

Floor-divides a by b.

Assumes both values are non-negative (as is typical for ONNX shape dimensions). Returns a // b as an int64_t when both are integers; otherwise builds "(a)//(b)" and simplifies.

dim_div(DimType{int64_t{7}}, DimType{int64_t{2}}) == DimType{int64_t{3}};
dim_div(DimType{std::string{"2*n"}}, DimType{int64_t{2}}) == DimType{std::string{"n"}};

Parameters:
  • a – The dividend.

  • b – The divisor.

Returns:

The floor-division result.

DimType dim_mod(const DimType &a, const DimType &b)

Computes a modulo b.

Returns a % b as an int64_t when both are integers; otherwise builds "(a)%(b)" and simplifies.

dim_mod(DimType{int64_t{10}}, DimType{int64_t{3}}) == DimType{int64_t{1}};

Parameters:
  • a – The dividend.

  • b – The divisor.

Returns:

The remainder.

DimType dim_max(const DimType &a, const DimType &b)

Returns the maximum of a and b.

Returns max(a, b) as an int64_t when both are integers; otherwise builds "(a)^(b)" (the xor encoding of max) and simplifies.

dim_max(DimType{int64_t{7}}, DimType{int64_t{3}}) == DimType{int64_t{7}};
// dim_max("n", "n") simplifies to DimType{std::string{"n"}} (x^x → x)

Parameters:
  • a – The first dimension.

  • b – The second dimension.

Returns:

The maximum of the two dimensions.

DimType dim_min(const DimType &a, const DimType &b)

Returns the minimum of a and b.

Returns min(a, b) as an int64_t when both are integers; otherwise builds "(a)&(b)" (the ampersand encoding of min) and simplifies.

dim_min(DimType{int64_t{2}}, DimType{int64_t{9}}) == DimType{int64_t{2}};

Parameters:
  • a – The first dimension.

  • b – The second dimension.

Returns:

The minimum of the two dimensions.

struct BinOp : public onnx_light::onnx_optim::expressions::Node
#include <expressions.h>

Interior node representing a binary arithmetic operation.

auto b = std::make_unique<BinOp>(
    std::make_unique<Name>("a"),
    BinOpKind::Add,
    std::make_unique<Constant>(1));
// unparse(*b) == "a+1"

Public Functions

inline BinOp(NodePtr l, BinOpKind o, NodePtr r)

Constructs a BinOp node.

Parameters:
  • l – Left operand (ownership transferred).

  • o – The operator kind.

  • r – Right operand (ownership transferred).

inline virtual NodePtr clone() const override

Returns a deep copy of this node and its entire sub-tree.

Returns:

An owning pointer to the cloned node.

Public Members

NodePtr left

Left-hand operand.

BinOpKind op

The binary operator.

NodePtr right

Right-hand operand.

struct Call : public onnx_light::onnx_optim::expressions::Node
#include <expressions.h>

Interior node representing a function call (e.g. CeilToInt, Max).

The only function call understood by evaluate_expression is CeilToInt(n, div), which computes ceiling division. Max(a, b) is syntactic sugar that MaxToXorTransformer rewrites to a ^ b before evaluation.

// After MaxToXorTransformer, Max(a, b) becomes BinOp(a, BitXor, b).
// CeilToInt(n, 2) is evaluated as (n % 2 == 0) ? n/2 : n/2+1.

Public Functions

inline Call(std::string f, std::vector<NodePtr> a)

Constructs a Call node.

Parameters:
  • f – The function name.

  • a – The argument list (ownership transferred).

inline virtual NodePtr clone() const override

Returns a deep copy of this node and its entire sub-tree.

Returns:

An owning pointer to the cloned node.

Public Members

std::string func

The function name (e.g. "CeilToInt", "Max").

std::vector<NodePtr> args

Positional arguments (ownership held).

struct Constant : public onnx_light::onnx_optim::expressions::Node
#include <expressions.h>

Leaf node representing a signed 64-bit integer constant.

auto c = std::make_unique<Constant>(42);
// unparse(*c) == "42"

Public Functions

inline explicit Constant(int64_t v)

Constructs a Constant node.

Parameters:

v – The integer value.

inline virtual NodePtr clone() const override

Returns a deep copy of this node and its entire sub-tree.

Returns:

An owning pointer to the cloned node.

Public Members

int64_t value

The integer value of this constant.

struct Name : public onnx_light::onnx_optim::expressions::Node
#include <expressions.h>

Leaf node representing a symbolic variable reference.

auto n = std::make_unique<Name>("batch");
// unparse(*n) == "batch"

Public Functions

inline explicit Name(std::string s)

Constructs a Name node.

Parameters:

s – The variable name string.

inline virtual NodePtr clone() const override

Returns a deep copy of this node and its entire sub-tree.

Returns:

An owning pointer to the cloned node.

Public Members

std::string id

The variable name (e.g. "batch", "seq_length").

struct Node
#include <expressions.h>

Abstract base class for all expression AST nodes.

Every concrete node type derives from Node and overrides clone() to produce a deep copy. Nodes are always heap-allocated and owned by NodePtr.

Subclassed by onnx_light::onnx_optim::expressions::BinOp, onnx_light::onnx_optim::expressions::Call, onnx_light::onnx_optim::expressions::Constant, onnx_light::onnx_optim::expressions::Name, onnx_light::onnx_optim::expressions::UnaryOp

Public Functions

virtual ~Node() = default

Destroys the node and recursively frees all child nodes.

virtual NodePtr clone() const = 0

Returns a deep copy of this node and its entire sub-tree.

Returns:

An owning pointer to the cloned node.
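The clone() contract can be sketched with the usual virtual-copy idiom: a leaf copies its value, and an interior node clones each child before wrapping the copies in a new node. The two-class hierarchy below is a trimmed illustration (Constant, plus a BinOp without its operator field), not the header's definitions.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>

struct Node;
using NodePtr = std::unique_ptr<Node>;

struct Node {
  virtual ~Node() = default;
  virtual NodePtr clone() const = 0;
};

struct Constant : Node {
  int64_t value;
  explicit Constant(int64_t v) : value(v) {}
  NodePtr clone() const override { return std::make_unique<Constant>(value); }
};

// Trimmed BinOp: operator field omitted; both children are deep-copied.
struct BinOp : Node {
  NodePtr left, right;
  BinOp(NodePtr l, NodePtr r) : left(std::move(l)), right(std::move(r)) {}
  NodePtr clone() const override {
    return std::make_unique<BinOp>(left->clone(), right->clone());
  }
};
```

Because each node owns its children through unique_ptr, the copy shares no nodes with the original, so mutating one tree never affects the other.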

struct UnaryOp : public onnx_light::onnx_optim::expressions::Node
#include <expressions.h>

Interior node representing a unary arithmetic operation.

auto u = std::make_unique<UnaryOp>(UnaryOpKind::USub,
                                   std::make_unique<Name>("x"));
// unparse(*u) == "-x"

Public Functions

inline UnaryOp(UnaryOpKind o, NodePtr n)

Constructs a UnaryOp node.

Parameters:
  • o – The operator kind.

  • n – The operand (ownership transferred).

inline virtual NodePtr clone() const override

Returns a deep copy of this node and its entire sub-tree.

Returns:

An owning pointer to the cloned node.

Public Members

UnaryOpKind op

The unary operator.

NodePtr operand

The operand.