.. _l-howto-shape-inference:

Shape inference
===============

This page answers common *"how do I…"* questions for running shape inference
on an ONNX model with :class:`~yobx.xshape.BasicShapeBuilder`.

----

How to run shape inference on an ONNX model
--------------------------------------------

Instantiate :class:`~yobx.xshape.BasicShapeBuilder` and call
:meth:`~yobx.xshape.shape_builder_impl.BasicShapeBuilder.run_model`.
After the call, query each result tensor by name with
:meth:`~yobx.xshape.ShapeBuilder.get_shape`.

:func:`onnx.shape_inference.infer_shapes` propagates integer and named
dimensions but does not compute arithmetic on symbolic ones. In contrast,
:class:`~yobx.xshape.BasicShapeBuilder` keeps each dimension as a symbolic
arithmetic expression, so output shapes are expressed in terms of the input
dimension names (e.g. ``batch``, ``seq``).

.. runpython::
    :showcode:

    import numpy as np
    import onnx
    import onnx.helper as oh
    import onnx.numpy_helper as onh
    from yobx.xshape import BasicShapeBuilder

    TFLOAT = onnx.TensorProto.FLOAT
    TINT64 = onnx.TensorProto.INT64

    model = oh.make_model(
        oh.make_graph(
            [
                oh.make_node("Add", ["X", "Y"], ["added"]),
                oh.make_node("Concat", ["added", "X"], ["concat_out"], axis=2),
                oh.make_node("Reshape", ["concat_out", "reshape_shape"], ["Z"]),
            ],
            "add_concat_reshape",
            [
                oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", "d_model"]),
                oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", "d_model"]),
            ],
            [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
            [onh.from_array(np.array([0, 0, -1], dtype=np.int64), name="reshape_shape")],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=10,
    )

    builder = BasicShapeBuilder()
    builder.run_model(model)

    for name in ["X", "Y", "added", "concat_out", "Z"]:
        print(f" {name:15s} shape={builder.get_shape(name)}")

----

How to compare shape inference approaches
------------------------------------------

Three tools are commonly used for ONNX shape inference:

1. :func:`onnx.shape_inference.infer_shapes` is the standard ONNX tool.
   It propagates both integer and named symbolic dimensions (``dim_param``)
   through the graph. When an output dimension cannot be determined from its
   inputs (for example, the result of concatenating two ``d_model``-wide
   axes), it assigns a freshly generated symbol (e.g. ``unk__0``) rather than
   computing the arithmetic relationship ``2*d_model``. Truly data-dependent
   dimensions (such as the number of non-zero elements) remain ``None``.
2. `onnx-shape-inference `_ is a third-party package that uses `SymPy `_ to
   track dimension expressions on the :pypi:`onnx-ir` representation of the
   model.
3. :class:`~yobx.xshape.BasicShapeBuilder` is the built-in yobx tool.
   It keeps dimensions as symbolic arithmetic expressions and evaluates
   constant-shape tensors (such as the ``shape`` input of ``Reshape``) to
   propagate information through the graph.

The table printed by the following script illustrates the difference on a
model that contains dynamic dimensions:

.. runpython::
    :showcode:

    import numpy as np
    import pandas
    import onnx
    import onnx_ir as ir
    import onnx.helper as oh
    import onnx.numpy_helper as onh
    from onnx_shape_inference import infer_symbolic_shapes
    from yobx.xshape import BasicShapeBuilder

    TFLOAT = onnx.TensorProto.FLOAT

    model = oh.make_model(
        oh.make_graph(
            [
                oh.make_node("Add", ["X", "Y"], ["added"]),
                oh.make_node("Concat", ["added", "X"], ["concat_out"], axis=2),
                oh.make_node("Reshape", ["concat_out", "reshape_shape"], ["Z"]),
            ],
            "add_concat_reshape",
            [
                oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", "d_model"]),
                oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", "d_model"]),
            ],
            [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
            [onh.from_array(np.array([0, 0, -1], dtype=np.int64), name="reshape_shape")],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=10,
    )

    # onnx.shape_inference.infer_shapes
    inferred = onnx.shape_inference.infer_shapes(model)
    onnx_shapes = {}
    for vi in [*inferred.graph.input, *inferred.graph.value_info, *inferred.graph.output]:
        t = vi.type.tensor_type
        if t.HasField("shape"):
            onnx_shapes[vi.name] = tuple(
                d.dim_param if d.dim_param else (d.dim_value if d.dim_value else None)
                for d in t.shape.dim
            )
        else:
            onnx_shapes[vi.name] = "unknown"

    # onnx-shape-inference
    ir_model = ir.serde.deserialize_model(model)
    ir_model = infer_symbolic_shapes(ir_model)
    ir_shapes = {}
    for v in ir_model.graph.inputs:
        ir_shapes[v.name] = str(v.shape)
    for node in ir_model.graph:
        for out in node.outputs:
            ir_shapes[out.name] = str(out.shape)

    # BasicShapeBuilder
    builder = BasicShapeBuilder()
    builder.run_model(model)

    names = ["X", "Y", "added", "concat_out", "Z"]
    rows = []
    for name in names:
        rows.append({
            "name": name,
            "onnx": str(onnx_shapes.get(name, "unknown")),
            "onnx_ir": str(ir_shapes.get(name, "unknown")),
            "basic": str(builder.get_shape(name)),
        })

    print(pandas.DataFrame(rows).set_index("name").to_string())

**Key observations:**

- For ``added = Add(X, Y)``, all three tools correctly propagate the input
  shape ``('batch', 'seq', 'd_model')`` to the output.
- For ``concat_out = Concat(added, X, axis=2)``,
  :func:`onnx.shape_inference.infer_shapes` cannot compute
  ``d_model + d_model`` symbolically and assigns a fresh placeholder
  (``unk__0``), while :class:`~yobx.xshape.BasicShapeBuilder` derives
  ``2*d_model`` exactly. ``onnx-shape-inference`` reaches the same conclusion
  as :class:`~yobx.xshape.BasicShapeBuilder` here.
- For the ``Reshape`` output ``Z``, the constant ``reshape_shape`` tensor
  (``[0, 0, -1]``) allows :class:`~yobx.xshape.BasicShapeBuilder` to resolve
  the target shape and express ``Z`` as ``('batch', 'seq', '2*d_model')``.
  ``onnx-shape-inference`` may assign a fresh symbol to ``Z`` because it does
  not always evaluate constant-shape tensors, and
  :func:`onnx.shape_inference.infer_shapes` likewise cannot resolve the result.
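
The arithmetic behind the last two observations can be checked by hand.
The sketch below does not use yobx: ``concat_shape`` and ``resolve_reshape``
are hypothetical helpers written for this illustration. They apply the ONNX
``Concat`` and ``Reshape`` rules (``0`` copies the input dimension at the same
position, ``-1`` is inferred from the element count) to concrete integers,
assuming ``batch=2``, ``seq=5`` and ``d_model=8``.

.. runpython::
    :showcode:

    import math


    def concat_shape(shapes, axis):
        """Shape of Concat: sizes add up along ``axis``, other axes are kept."""
        out = list(shapes[0])
        out[axis] = sum(s[axis] for s in shapes)
        return tuple(out)


    def resolve_reshape(input_shape, target):
        """Apply the ONNX Reshape rules (allowzero=0) to integer shapes."""
        # 0 copies the input dimension found at the same position.
        out = [input_shape[i] if d == 0 else d for i, d in enumerate(target)]
        if -1 in out:
            # -1 takes whatever remains of the total element count.
            known = math.prod(d for d in out if d != -1)
            out[out.index(-1)] = math.prod(input_shape) // known
        return tuple(out)


    # Assumed concrete values for the symbolic dimensions.
    batch, seq, d_model = 2, 5, 8
    x_shape = (batch, seq, d_model)

    concat_out = concat_shape([x_shape, x_shape], axis=2)
    z_shape = resolve_reshape(concat_out, [0, 0, -1])

    print("concat_out:", concat_out)
    print("Z         :", z_shape)

For these values both shapes equal ``(batch, seq, 2*d_model)``, which is the
relationship :class:`~yobx.xshape.BasicShapeBuilder` reports symbolically.
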

----

How to evaluate symbolic shapes with concrete values
-----------------------------------------------------

Once the model has been analysed, call
:meth:`~yobx.xshape.ShapeBuilder.evaluate_shape` with a dictionary of
``{dim_name: int}`` values to substitute the symbolic variables and obtain
concrete integer shapes.

.. runpython::
    :showcode:

    import numpy as np
    import onnx
    import onnx.helper as oh
    import onnx.numpy_helper as onh
    from yobx.xshape import BasicShapeBuilder

    TFLOAT = onnx.TensorProto.FLOAT

    model = oh.make_model(
        oh.make_graph(
            [
                oh.make_node("Add", ["X", "Y"], ["added"]),
                oh.make_node("Concat", ["added", "X"], ["concat_out"], axis=2),
                oh.make_node("Reshape", ["concat_out", "reshape_shape"], ["Z"]),
            ],
            "add_concat_reshape",
            [
                oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", "d_model"]),
                oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", "d_model"]),
            ],
            [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
            [onh.from_array(np.array([0, 0, -1], dtype=np.int64), name="reshape_shape")],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=10,
    )

    builder = BasicShapeBuilder()
    builder.run_model(model)

    context = {"batch": 2, "seq": 5, "d_model": 8}
    for name in ["X", "Y", "added", "concat_out", "Z"]:
        sym = builder.get_shape(name)
        concrete = builder.evaluate_shape(name, context)
        print(f" {name:15s} symbolic={sym!s:30s} concrete={concrete}")

----

How to estimate the computational cost of a model
--------------------------------------------------

Pass ``inference=InferenceMode.COST`` to
:meth:`~yobx.xshape.shape_builder_impl.BasicShapeBuilder.run_model`.
The method returns a list of ``(op_type, flops, node)`` triples where *flops*
is an integer (static shapes), a symbolic string expression (dynamic shapes),
or ``None`` (unsupported operator or unknown shapes). Cost is expressed in
*floating-point operations* (FLOPs).

.. runpython::
    :showcode:

    import numpy as np
    import onnx
    import onnx.helper as oh
    from yobx.xshape import BasicShapeBuilder, InferenceMode

    TFLOAT = onnx.TensorProto.FLOAT

    model = oh.make_model(
        oh.make_graph(
            [
                oh.make_node("MatMul", ["A", "B"], ["C"]),
                oh.make_node("Relu", ["C"], ["out"]),
            ],
            "matmul_relu",
            [
                oh.make_tensor_value_info("A", TFLOAT, ["batch", "M", "K"]),
                oh.make_tensor_value_info("B", TFLOAT, ["batch", "K", "N"]),
            ],
            [oh.make_tensor_value_info("out", TFLOAT, None)],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=10,
    )

    builder = BasicShapeBuilder()
    cost_list = builder.run_model(model, inference=InferenceMode.COST)

    print("Symbolic FLOPs per node:")
    for op_type, flops, _ in cost_list:
        print(f" {op_type:<12s} {flops}")

To substitute concrete dimension values and obtain integer FLOP counts, use
:meth:`~yobx.xshape.shape_builder_impl.BasicShapeBuilder.evaluate_cost_with_true_inputs`:

.. runpython::
    :showcode:

    import numpy as np
    import onnx
    import onnx.helper as oh
    from yobx.xshape import BasicShapeBuilder, InferenceMode

    TFLOAT = onnx.TensorProto.FLOAT

    model = oh.make_model(
        oh.make_graph(
            [
                oh.make_node("MatMul", ["A", "B"], ["C"]),
                oh.make_node("Relu", ["C"], ["out"]),
            ],
            "matmul_relu",
            [
                oh.make_tensor_value_info("A", TFLOAT, ["batch", "M", "K"]),
                oh.make_tensor_value_info("B", TFLOAT, ["batch", "K", "N"]),
            ],
            [oh.make_tensor_value_info("out", TFLOAT, None)],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=10,
    )

    builder = BasicShapeBuilder()
    cost_list = builder.run_model(model, inference=InferenceMode.COST)

    rng = np.random.default_rng(0)
    feeds = {
        "A": rng.standard_normal((4, 32, 64)).astype(np.float32),
        "B": rng.standard_normal((4, 64, 16)).astype(np.float32),
    }
    concrete = builder.evaluate_cost_with_true_inputs(feeds, cost_list)

    total = 0
    print("Concrete FLOPs per node:")
    for op_type, flops, _ in concrete:
        total += flops or 0
        # flops may be None for an unsupported operator; format it safely.
        shown = f"{flops:,}" if flops is not None else "n/a"
        print(f" {op_type:<12s} {shown:>12s}")
    print(f" {'TOTAL':<12s} {total:>12,}")

**Cost considerations:**

Running :class:`~yobx.xshape.BasicShapeBuilder` itself introduces overhead
because it must interpret every node in the graph and evaluate any constant
sub-expressions it finds. For very large models (thousands of nodes) this can
take noticeably longer than :func:`onnx.shape_inference.infer_shapes`, which
relies on a compiled C++ backend.
:class:`~yobx.xshape.BasicShapeBuilder` is best used offline (during model
export or optimisation) rather than in a hot inference path.

----

How to work with constraints from named output dimensions
----------------------------------------------------------

Some operators, such as ``NonZero``, introduce a *data-dependent* dimension
whose size cannot be determined from shapes alone.
:class:`~yobx.xshape.BasicShapeBuilder` assigns an internal placeholder name
(e.g. ``NEWDIM_nonzero_0``) to such a dimension.

When the graph output is annotated with **named dimensions**, the builder
detects the mismatch between the computed placeholder and the user-supplied
name, registers the *constraint* ``NEWDIM_nonzero_0 = nnz``, and renames the
placeholder throughout the graph. Inspect registered constraints with
:meth:`~yobx.xshape.ShapeBuilder.get_registered_constraints`.

.. runpython::
    :showcode:

    import onnx
    import onnx.helper as oh
    from yobx.xshape import BasicShapeBuilder

    TFLOAT = onnx.TensorProto.FLOAT
    TINT64 = onnx.TensorProto.INT64

    nodes = [
        oh.make_node("Abs", ["X"], ["abs_out"]),
        oh.make_node("NonZero", ["abs_out"], ["nz"]),
        oh.make_node("Transpose", ["nz"], ["transposed_nz"]),
    ]
    inputs = [oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq"])]

    # --- anonymous output shapes: placeholder is kept as-is ---
    model_anon = oh.make_model(
        oh.make_graph(
            nodes,
            "nonzero_anon",
            inputs,
            [oh.make_tensor_value_info("transposed_nz", TINT64, [None, None])],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=10,
    )

    # --- named output shapes: constraint is registered ---
    model_named = oh.make_model(
        oh.make_graph(
            nodes,
            "nonzero_named",
            inputs,
            [oh.make_tensor_value_info("transposed_nz", TINT64, ["nnz", "rank"])],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=10,
    )

    b_anon = BasicShapeBuilder()
    b_anon.run_model(model_anon)
    print("anonymous - nz shape :", b_anon.get_shape("nz"))
    print("anonymous - transposed shape :", b_anon.get_shape("transposed_nz"))
    print("anonymous - constraints :", b_anon.get_registered_constraints())
    print()

    b_named = BasicShapeBuilder()
    b_named.run_model(model_named)
    print("named - nz shape :", b_named.get_shape("nz"))
    print("named - transposed shape :", b_named.get_shape("transposed_nz"))
    print("named - constraints :", b_named.get_registered_constraints())

**When to use named output dimensions:**

Provide named dimensions in the graph output annotations whenever the graph
contains data-dependent operators and you want downstream code to be able to
reference those dimensions by a stable name. The constraint
``NEWDIM_nonzero_0 = nnz`` acts as a *rename directive* that propagates the
user-visible name throughout the entire symbolic shape graph, making the
output of every subsequent node easier to interpret.

.. seealso::

    :ref:`l-design-shape`: design document describing the algorithm behind
    :class:`~yobx.xshape.BasicShapeBuilder` in detail.

    :ref:`l-design-cost`: explanation of how FLOPs are counted for each
    operator type.

    :ref:`l-plot-computed-shapes`: gallery example comparing all three
    shape-inference tools on a larger model.

    :ref:`l-plot-symbolic-cost`: gallery example estimating the cost of an
    attention model before and after optimisation.