shape_tensor.h
Shape-inference functions for ONNX operators in the tensor family.
-
namespace ONNX_LIGHT_NAMESPACE
-
namespace onnx_optim
-
namespace shapes
-
namespace tensor
Functions
-
void ComputeShapeConcat(ShapesContext &ctx, const NodeProto &node)
Computes the output OptimTensor of a Concat node and stores it in ctx.
Concat concatenates a variadic list of input tensors along the axis specified by the axis attribute. All inputs must share the same rank and the same dimension sizes on every axis other than the concatenation axis. The output dtype always matches the dtype of the first input (type constraint T). The output shape is:
- on the concatenation axis: the sum of all input dimensions when every input dimension on that axis is a concrete integer; otherwise a fresh symbolic dimension;
- on every other axis: the merged dimension between all inputs (concrete dimensions must match across inputs, otherwise an exception is thrown; a concrete value overrides a symbolic one).
The axis attribute can be negative, in which case it is interpreted as axis + rank. When the attribute is missing, the default of 1 (the opset 1 default) is used.
- Parameters:
ctx – In/out context. Must already contain an entry for every name in node.input. On return it also contains an entry for node.output(0).
node – The Concat NodeProto whose output should be described. node.op_type() must be "Concat"; node must declare at least one input and at least one output.
- Throws:
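Example (illustrative sketch). The Concat shape rule described above can be sketched in standalone C++ as follows. The names Dim, Shape, NormalizeAxis and ConcatShape are hypothetical and are not part of this header; a symbolic dimension is modelled as an empty std::optional rather than with the library's OptimTensor, and the default value of axis is not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins: a dimension is either a concrete size or
// std::nullopt (symbolic). The real OptimTensor type is not used here.
using Dim = std::optional<int64_t>;
using Shape = std::vector<Dim>;

// Maps a possibly negative axis into [0, rank).
inline int64_t NormalizeAxis(int64_t axis, int64_t rank) {
  if (axis < -rank || axis >= rank)
    throw std::invalid_argument("axis out of range");
  return axis < 0 ? axis + rank : axis;
}

// Sketch of the Concat rule: sum on the concatenation axis, merge elsewhere.
inline Shape ConcatShape(const std::vector<Shape>& inputs, int64_t axis) {
  if (inputs.empty()) throw std::invalid_argument("Concat needs an input");
  const int64_t rank = static_cast<int64_t>(inputs[0].size());
  const std::size_t ax = static_cast<std::size_t>(NormalizeAxis(axis, rank));
  Shape out = inputs[0];
  for (std::size_t i = 1; i < inputs.size(); ++i) {
    const Shape& in = inputs[i];
    if (in.size() != out.size()) throw std::invalid_argument("rank mismatch");
    for (std::size_t d = 0; d < out.size(); ++d) {
      if (d == ax) {
        // The sum stays concrete only while every contribution is concrete.
        if (out[d] && in[d]) {
          out[d] = *out[d] + *in[d];
        } else {
          out[d] = std::nullopt;
        }
      } else if (out[d] && in[d]) {
        if (*out[d] != *in[d])
          throw std::invalid_argument("dimension mismatch");
      } else if (in[d]) {
        out[d] = in[d];  // a concrete value overrides a symbolic one
      }
    }
  }
  return out;
}
```

For instance, concatenating shapes (2, 3) and (2, 5) with axis = -1 yields (2, 8) under this sketch, while a symbolic entry on the concatenation axis makes the corresponding output dimension symbolic.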
-
void ComputeShapeCast(ShapesContext &ctx, const NodeProto &node)
Computes the output OptimTensor of a Cast node and stores it in ctx.
Cast produces an output whose shape is identical to the shape of its single input and whose element type is given by the required integer attribute to (a TensorProto::DataType value). The other optional attributes (saturate, round_mode) do not affect the output shape or dtype and are therefore not inspected by this function.
- Parameters:
ctx – In/out context. Must already contain an entry for node.input(0). On return it also contains an entry for node.output(0).
node – The Cast NodeProto whose output should be described. node.op_type() must be "Cast"; node must declare at least one input and at least one output and must carry the required to attribute.
- Throws:
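Example (illustrative sketch). The Cast rule above amounts to copying the input shape and replacing the element type. The TensorInfo struct and the CastInfo function below are hypothetical and only stand in for the library's own types; the exception type thrown for a missing to attribute is likewise an assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical description of a tensor: an element type plus a shape whose
// dimensions are either concrete or symbolic (std::nullopt).
struct TensorInfo {
  int32_t dtype;  // a TensorProto::DataType value, e.g. 1 = FLOAT, 7 = INT64
  std::vector<std::optional<int64_t>> shape;
};

// Sketch of the Cast rule: keep the shape, replace the element type.
// `to` is std::nullopt when the required attribute is absent (assumed to be
// reported with std::invalid_argument in this sketch).
inline TensorInfo CastInfo(const TensorInfo& input, std::optional<int32_t> to) {
  if (!to) throw std::invalid_argument("Cast requires the 'to' attribute");
  return TensorInfo{*to, input.shape};
}
```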
-
void ComputeShapeReshape(ShapesContext &ctx, const NodeProto &node)
Computes the output OptimTensor of a Reshape node and stores it in ctx.
Reshape takes a data tensor and a 1-D int64 shape tensor whose values describe the desired output shape. The output dtype is the dtype of data (type constraint T). The output shape is derived element-by-element from the target shape:
- a positive value is used verbatim;
- 0 means “copy from the input data shape at the same index”, unless the allowzero attribute is set to 1 (in which case 0 is honoured literally);
- at most one -1 is allowed; the corresponding dimension is inferred so that the total number of elements is preserved (when data is fully known and the other dims are concrete);
- symbolic target dims are forwarded as symbolic output dims.
Shape values are read from the shape input’s OptimTensor::ValueAsShape annotation (populated for small constants, e.g. by ComputeShapeConstant). When that annotation is missing, the output rank is taken from the static shape of the shape input (its single dimension, when concrete) and every output dim is left symbolic. When the rank itself is unknown, the output is left as a fully-symbolic rank-1 tensor.
- Parameters:
ctx – In/out context. Must contain entries for node.input(0) (data) and node.input(1) (shape). On return it also contains an entry for node.output(0).
node – The Reshape NodeProto whose output should be described. node.op_type() must be "Reshape"; node must declare two inputs and at least one output.
- Throws:
std::invalid_argument – if node.op_type() is not "Reshape", if node has fewer than two inputs or no output, if the target shape contains more than one -1 or a value strictly less than -1, if a 0 entry (with allowzero == 0) refers to a position outside the input rank, or if a -1 cannot be reconciled with the input’s element count.
std::out_of_range – if any input name is missing from ctx.
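Example (illustrative sketch). The element-by-element rules for the Reshape target shape can be sketched in standalone C++ as follows. Dim, Shape and ReshapeShape are hypothetical names, not part of this header; a symbolic dimension is modelled as an empty std::optional, and the fallback paths for a missing ValueAsShape annotation are not covered.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

using Dim = std::optional<int64_t>;  // std::nullopt = symbolic dimension
using Shape = std::vector<Dim>;

// Sketch of the Reshape rule. `target` holds the values of the shape input;
// a std::nullopt entry stands for a symbolic target dimension.
inline Shape ReshapeShape(const Shape& data, const Shape& target,
                          bool allowzero) {
  Shape out(target.size());             // every dim starts out symbolic
  std::optional<std::size_t> infer_at;  // position of the single -1, if any
  int64_t known = 1;                    // product of the concrete output dims
  bool all_known = true;                // false once a symbolic dim is seen
  for (std::size_t i = 0; i < target.size(); ++i) {
    if (!target[i]) {  // symbolic target dim: forwarded as symbolic
      all_known = false;
      continue;
    }
    const int64_t v = *target[i];
    if (v < -1) throw std::invalid_argument("invalid target dimension");
    if (v == -1) {
      if (infer_at) throw std::invalid_argument("more than one -1");
      infer_at = i;
    } else if (v == 0 && !allowzero) {
      if (i >= data.size())
        throw std::invalid_argument("0 entry outside the input rank");
      out[i] = data[i];  // copy from the input at the same index
      if (out[i]) known *= *out[i]; else all_known = false;
    } else {
      out[i] = v;  // positive value, or a literal 0 when allowzero is set
      known *= v;
    }
  }
  if (!infer_at || !all_known) return out;
  int64_t total = 1;
  for (const Dim& d : data) {
    if (!d) return out;  // data not fully known: the -1 dim stays symbolic
    total *= *d;
  }
  // This sketch rejects a -1 that cannot be resolved to a unique size.
  if (known == 0 || total % known != 0)
    throw std::invalid_argument("cannot infer the -1 dimension");
  out[*infer_at] = total / known;
  return out;
}
```

With data of shape (2, 3, 4) and a target of (0, -1), this sketch copies the leading 2 and infers 12 for the second dimension; with allowzero set, a 0 entry is kept as a literal zero-sized dimension.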