shape_inference.h
Top-level shape-inference helpers that dispatch to the per-operator ComputeShape* functions of onnx_optim.
The per-operator functions (for example :cpp:func:`onnx_optim::shapes::math::ComputeShapeAbs`) each take their inputs by name. The helpers in this file walk a single `NodeProto` or a topologically sorted sequence of `NodeProto` (typically `GraphProto::node()`), look up each node's op type and forward the call to the matching `ComputeShape*` implementation.
-
namespace ONNX_LIGHT_NAMESPACE
-
namespace onnx_optim
-
namespace shapes
Functions
-
void ComputeShapeNode(ShapesContext &ctx, const NodeProto &node)
Dispatches a single `NodeProto` to the matching per-operator `ComputeShape*` function and stores the resulting output :cpp:class:`OptimTensor` descriptors in `ctx`.

The dispatch table is keyed on `node.op_type()`. Only operators that `onnx_optim` currently knows about are accepted; unsupported op types throw `std::invalid_argument`. The node's input descriptors are read from `ctx` by name (so every input must already be present), and the output descriptors are inserted into `ctx` under the names declared by `node.output(i)`.

The domain of `node` is also checked: only nodes belonging to the default ONNX domain (the empty string or `"ai.onnx"`) are supported.
- Parameters:
ctx – In/out context. Must already contain entries for every input referenced by `node`; on return it also contains entries for every output declared by `node`.
node – The `NodeProto` whose output shapes should be computed.
- Throws:
std::invalid_argument – when `node.op_type()` is not a supported operator.
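The dispatch described above can be pictured with the following C++17 sketch. It is not the onnx_optim implementation: `Node`, `Context`, `Handler`, `DispatchTable` and `ComputeShapeNodeSketch` are hypothetical stand-ins for `NodeProto`, `ShapesContext` and the real dispatch table, shapes are plain integer vectors, only two operators are registered, and rejecting a non-default domain with `std::invalid_argument` is an assumption.

```cpp
// Sketch only: Node, Context and ComputeShapeNodeSketch are hypothetical
// stand-ins, not the onnx_optim types or functions.
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

using Shape = std::vector<long long>;          // concrete dimensions only
using Context = std::map<std::string, Shape>;  // value name -> shape

struct Node {
  std::string op_type;
  std::string domain;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// A per-operator handler maps input shapes to output shapes.
using Handler = std::function<std::vector<Shape>(const std::vector<Shape>&)>;

const std::map<std::string, Handler>& DispatchTable() {
  static const std::map<std::string, Handler> table = {
      // Elementwise unary op: the output has the input's shape.
      {"Abs", [](const std::vector<Shape>& in) {
         return std::vector<Shape>{in.at(0)};
       }},
      // Transpose without a perm attribute reverses the dimensions.
      {"Transpose", [](const std::vector<Shape>& in) {
         return std::vector<Shape>{Shape(in.at(0).rbegin(), in.at(0).rend())};
       }},
  };
  return table;
}

void ComputeShapeNodeSketch(Context& ctx, const Node& node) {
  // Assumption: a non-default domain is rejected with invalid_argument.
  if (!node.domain.empty() && node.domain != "ai.onnx")
    throw std::invalid_argument("unsupported domain: " + node.domain);
  const auto& table = DispatchTable();
  auto it = table.find(node.op_type);
  if (it == table.end())
    throw std::invalid_argument("unsupported op_type: " + node.op_type);
  std::vector<Shape> in;
  for (const auto& name : node.inputs)
    in.push_back(ctx.at(name));  // at() throws std::out_of_range if missing
  std::vector<Shape> out = it->second(in);
  for (std::size_t i = 0; i < node.outputs.size() && i < out.size(); ++i)
    ctx[node.outputs[i]] = out[i];
}
```

Keying a `std::map` on the op type keeps the lookup and the unsupported-op error in one place; the real table may well be organized differently.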
-
void CheckInputsAvailable(const ShapesContext &ctx, const NodeProto &node)
Throws `std::invalid_argument` if any non-empty input name declared by `node` is missing from `ctx`. Empty input names, which ONNX uses to denote optional inputs that are not provided, are skipped.
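A minimal sketch of the skip-empty-names rule, assuming the context can be reduced to the set of names it already describes. `CheckInputsAvailableSketch` is a hypothetical stand-in, not the onnx_optim function.

```cpp
// Sketch only: CheckInputsAvailableSketch is a hypothetical stand-in that
// takes the set of described names instead of a ShapesContext.
#include <cassert>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

void CheckInputsAvailableSketch(const std::set<std::string>& described,
                                const std::vector<std::string>& inputs) {
  for (const auto& name : inputs) {
    if (name.empty()) continue;  // omitted optional input: nothing to look up
    if (described.count(name) == 0)
      throw std::invalid_argument("input not described: " + name);
  }
}
```

For instance, an ONNX `Clip` node with inputs `{"x", "", "max"}` omits its optional `min` input, so only `x` and `max` need descriptors.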
-
void CheckOutputsNotAvailable(const ShapesContext &ctx, const NodeProto &node)
Throws `std::invalid_argument` if any non-empty output name declared by `node` already has an entry in `ctx`. Empty output names, which ONNX uses to denote optional outputs that are not produced, are skipped.

This is intended as a precondition check for shape-inference passes that build up `ctx` incrementally and should never overwrite an existing descriptor.
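The precondition role of this check can be sketched as an incremental pass that validates before recording. `Context`, `CheckOutputsNotAvailableSketch` and `RecordOutputs` are hypothetical; the context is reduced to a map from value name to a dtype string.

```cpp
// Sketch only: Context maps a value name to a dtype string, standing in for
// ShapesContext; both functions below are hypothetical.
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

using Context = std::map<std::string, std::string>;

void CheckOutputsNotAvailableSketch(const Context& ctx,
                                    const std::vector<std::string>& outputs) {
  for (const auto& name : outputs) {
    if (name.empty()) continue;  // optional output that is not produced
    if (ctx.count(name) != 0)
      throw std::invalid_argument("output already described: " + name);
  }
}

// Incremental pass step: validate first, then record, so an existing
// descriptor is never silently replaced.
void RecordOutputs(Context& ctx, const std::vector<std::string>& outputs,
                   const std::string& dtype) {
  CheckOutputsNotAvailableSketch(ctx, outputs);
  for (const auto& name : outputs)
    if (!name.empty()) ctx.emplace(name, dtype);
}
```

With a `Dropout` node whose optional `mask` output is left unnamed (`{"y", ""}`), only `y` is recorded; recording `y` a second time throws before anything is written, so the existing entry is left untouched.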
-
void ComputeShapes(ShapesContext &ctx, const utils::RepeatedProtoField<NodeProto> &nodes)
Runs :cpp:func:`ComputeShapeNode` on every node of `nodes` in order. The sequence must be topologically sorted with respect to data dependencies (as the ONNX specification requires for `GraphProto::node`), so that by the time a node is processed, every one of its inputs has already been described in `ctx`, either as a pre-existing graph input/initializer or as the output of an earlier node in the sequence.
- Parameters:
ctx – In/out context. On entry it must contain descriptors for every graph input and initializer referenced by `nodes`; on return it additionally contains descriptors for every output of every node in `nodes`.
nodes – The list of nodes to process, in topological order.
- Throws:
std::invalid_argument – propagated from :cpp:func:`ComputeShapeNode`.
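Why the order matters shows up in a single-pass sketch. `Node` and `ComputeShapesSketch` are hypothetical, and every node is treated as an Identity (output shape equals input shape), which is enough to show that a node listed before the producer of its input fails, while the same nodes in topological order succeed.

```cpp
// Sketch only: Node and ComputeShapesSketch are hypothetical stand-ins, and
// every node is treated as an Identity (output shape == input shape).
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

using Shape = std::vector<long long>;
using Context = std::map<std::string, Shape>;

struct Node {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

void ComputeShapesSketch(Context& ctx, const std::vector<Node>& nodes) {
  for (const Node& node : nodes) {  // single forward pass, no reordering
    auto it = ctx.find(node.inputs.at(0));
    if (it == ctx.end())
      throw std::invalid_argument("input not yet described: " + node.inputs[0]);
    Shape shape = it->second;  // copy before inserting the output
    ctx[node.outputs.at(0)] = shape;
  }
}
```

Because the pass never looks ahead, it relies on the producer-before-consumer ordering instead of building a dependency graph of its own.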
-
void ComputeShapeGraph(ShapesContext &ctx, const GraphProto &graph)
Seeds `ctx` from the initializers and inputs of `graph` and then runs :cpp:func:`ComputeShapes` on its nodes.

For every entry in `graph.initializer()`, an :cpp:class:`OptimTensor` describing the initializer's element type and shape is inserted in `ctx` (small 1-D integer initializers also get a `ValueAsShape` annotation derived from their content, mirroring :cpp:func:`ComputeShapeConstant`).

For every entry in `graph.input()` whose name is not already present in `ctx` (initializers take precedence), an :cpp:class:`OptimTensor` describing the value's tensor type and shape is inserted. Inputs declared as sequence, map, optional or sparse types are skipped, because `OptimTensor` does not model them. Missing or symbolic dimensions are preserved as :cpp:class:`OptimDim` expressions: `dim_value` becomes a concrete integer dimension, `dim_param` becomes a symbolic dimension with the same name, and an unset dimension becomes a fresh `"?"` symbolic dimension.
- Parameters:
ctx – In/out context. On entry it may already contain outer-scope entries (for sub-graphs). On return it additionally contains descriptors for every graph initializer, every graph input (when not already present) and every node output.
graph – The graph whose initializers, inputs and nodes are processed (the nodes in topological order).
- Throws:
std::invalid_argument – propagated from :cpp:func:`ComputeShapeNode`.
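The seeding rules (initializers first, then inputs not already present, with the three dimension cases) can be sketched in C++17 as follows. `ProtoDim`, `DimSketch`, `ToDim` and `SeedSketch` are hypothetical stand-ins for `TensorShapeProto::Dimension`, `OptimDim` and the seeding step; element types, `ValueAsShape` annotations and non-tensor inputs are left out.

```cpp
// Sketch only: ProtoDim stands in for TensorShapeProto::Dimension and
// DimSketch for OptimDim; ToDim and SeedSketch are hypothetical.
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <variant>
#include <vector>

struct ProtoDim {
  std::optional<int64_t> dim_value;      // set for a concrete dimension
  std::optional<std::string> dim_param;  // set for a named symbolic dimension
};

using DimSketch = std::variant<int64_t, std::string>;  // integer or symbol
using ShapeSketch = std::vector<DimSketch>;
using Context = std::map<std::string, ShapeSketch>;

DimSketch ToDim(const ProtoDim& d) {
  if (d.dim_value) return *d.dim_value;  // concrete integer dimension
  if (d.dim_param) return *d.dim_param;  // symbolic, same name
  // Unset: the documented behavior creates a fresh "?" symbol each time;
  // this sketch only records the name, so unset dims compare equal here.
  return std::string("?");
}

void SeedSketch(Context& ctx,
                const std::map<std::string, ShapeSketch>& initializers,
                const std::map<std::string, std::vector<ProtoDim>>& inputs) {
  for (const auto& [name, shape] : initializers) ctx[name] = shape;
  for (const auto& [name, dims] : inputs) {
    if (ctx.count(name) != 0) continue;  // initializers take precedence
    ShapeSketch shape;
    for (const ProtoDim& d : dims) shape.push_back(ToDim(d));
    ctx.emplace(name, shape);
  }
}
```

A name that appears both as an initializer and as a graph input keeps the initializer's shape, which matters for older models that list initializers among the inputs.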
-
void ComputeShapeModel(ShapesContext &ctx, const ModelProto &model)
Runs shape inference on `model.graph()`.

Records every `(domain, version)` pair in `model.opset_import()` in `ctx` via :cpp:func:`ShapesContext::SetOpsetVersion` and then delegates to :cpp:func:`ComputeShapeGraph`. Any opset entry already recorded in `ctx` is overwritten, so the values from the `ModelProto` take precedence.
- Parameters:
ctx – In/out context. On return it contains the model's opset versions, the graph initializers, the graph inputs and every intermediate value computed by the graph nodes.
model – The model whose main graph is processed.
- Throws:
std::invalid_argument – when `model` has no graph or when shape inference of the graph rejects a node.
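The overwrite rule for opsets amounts to an insert-or-assign. In this sketch `ModelSketch` and `RecordOpsets` are hypothetical, and a plain map stands in for the versions recorded through `ShapesContext::SetOpsetVersion`; the missing-graph check and the call to `ComputeShapeGraph` are omitted.

```cpp
// Sketch only: ModelSketch and RecordOpsets are hypothetical; the opset map
// plays the role of the versions stored through ShapesContext::SetOpsetVersion.
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct ModelSketch {
  std::vector<std::pair<std::string, int64_t>> opset_import;  // (domain, version)
};

using OpsetMap = std::map<std::string, int64_t>;

void RecordOpsets(OpsetMap& opsets, const ModelSketch& model) {
  for (const auto& [domain, version] : model.opset_import)
    opsets.insert_or_assign(domain, version);  // model values overwrite old ones
}
```

Using `insert_or_assign` instead of `emplace` is what gives the `ModelProto` precedence: `emplace` would keep the value already in the map.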
-
void ApplyInferredShapesToGraph(const ShapesContext &ctx, GraphProto &graph)
Writes the shape and element-type descriptors stored in `ctx` back into `graph` so that the inferred information persists in the proto representation.

For every output declared by `graph.output()` whose name is present in `ctx` (and whose `OptimTensor` has a known dtype), the matching `ValueInfoProto` is updated to carry the inferred tensor element type and `TensorShapeProto`. Concrete integer dimensions become `dim_value` entries and symbolic dimensions become `dim_param` entries.

For every other name in `ctx.Tensors()` (intermediate results), a fresh `ValueInfoProto` is appended to `graph.value_info()`. Names that already correspond to a graph input, initializer or existing `value_info` entry are skipped, so the function never overwrites authoritative type information already present in the proto.

Entries whose `OptimTensor` has the :cpp:enumerator:`TensorType::kUndefined` element type are skipped, since there is no valid `TensorProto::DataType` to write for them.
- Parameters:
ctx – Context populated by a previous call to :cpp:func:`ComputeShapeGraph` / :cpp:func:`ComputeShapeModel`.
graph – In/out graph whose `output` and `value_info` entries are updated in place.
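The `value_info` step and the dimension mapping can be sketched as follows. `TensorSketch`, `ProtoDim`, `ValueInfoSketch`, `ToProtoDim` and `AppendValueInfo` are hypothetical stand-ins; element types use ONNX's `TensorProto::DataType` numbering (0 is `UNDEFINED`, 1 is `FLOAT`) in place of `TensorType::kUndefined`, and the in-place update of `graph.output()` is not shown.

```cpp
// Sketch only: TensorSketch, ProtoDim, ValueInfoSketch and AppendValueInfo
// are hypothetical stand-ins for OptimTensor, TensorShapeProto::Dimension,
// ValueInfoProto and the value_info step of ApplyInferredShapesToGraph.
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <variant>
#include <vector>

using DimSketch = std::variant<int64_t, std::string>;

struct TensorSketch {
  int elem_type = 0;  // 0 is TensorProto::UNDEFINED in ONNX
  std::vector<DimSketch> dims;
};

struct ProtoDim {
  std::optional<int64_t> dim_value;
  std::optional<std::string> dim_param;
};

struct ValueInfoSketch {
  std::string name;
  int elem_type = 0;
  std::vector<ProtoDim> dims;
};

ProtoDim ToProtoDim(const DimSketch& d) {
  ProtoDim p;
  if (const auto* value = std::get_if<int64_t>(&d))
    p.dim_value = *value;                    // concrete -> dim_value
  else
    p.dim_param = std::get<std::string>(d);  // symbolic -> dim_param
  return p;
}

// `already_typed` holds graph inputs, initializers, outputs and existing
// value_info names; outputs are updated separately by the real function.
void AppendValueInfo(const std::map<std::string, TensorSketch>& ctx,
                     const std::set<std::string>& already_typed,
                     std::vector<ValueInfoSketch>& value_info) {
  for (const auto& [name, tensor] : ctx) {
    if (tensor.elem_type == 0) continue;           // nothing valid to write
    if (already_typed.count(name) != 0) continue;  // never overwrite proto info
    ValueInfoSketch info{name, tensor.elem_type, {}};
    for (const DimSketch& d : tensor.dims) info.dims.push_back(ToProtoDim(d));
    value_info.push_back(info);
  }
}
```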
-
void ApplyInferredShapesToModel(const ShapesContext &ctx, ModelProto &model)
Writes the shape and element-type descriptors stored in `ctx` back into `model.graph()` by delegating to :cpp:func:`ApplyInferredShapesToGraph`.
- Parameters:
ctx – Context populated by a previous call to :cpp:func:`ComputeShapeModel`.
model – In/out model whose main graph is updated in place.
- Throws:
std::invalid_argument – when `model` has no graph.
-
void InferShapesModel(ModelProto &model)
Convenience helper: runs :cpp:func:`ComputeShapeModel` on `model` and then :cpp:func:`ApplyInferredShapesToModel`, mutating `model` in place so that its `graph.output()` and `graph.value_info()` carry the inferred shapes.
- Parameters:
model – In/out model on which shape inference is run and whose proto is updated with the inferred results.
- Throws:
std::invalid_argument – when `model` has no graph or when shape inference of the graph rejects a node.
-
namespace shapes
-
namespace onnx_optim