shape_broadcast.h
Shared helpers for shape-inference functions of binary ONNX operators that support numpy-style (multidirectional) broadcasting.

namespace ONNX_LIGHT_NAMESPACE
namespace onnx_optim
namespace shapes

Functions

OptimShape BroadcastShapes(const OptimShape &a, const OptimShape &b)
Computes the broadcast result shape of two OptimShape operands following the ONNX (numpy-style) multidirectional broadcasting rules. The shapes are right-aligned and dimensions are paired starting from the trailing axis; missing leading dimensions are treated as 1. For each paired dimension (d_a, d_b), the resulting dimension is computed as follows:
- if both are concrete integers, standard broadcasting rules are enforced: equal dimensions, or a dimension of 1 paired with anything, are accepted, and mismatching non-unit integers throw std::invalid_argument;
- if either is the integer 1, the result is the other dimension;
- if both are equal (same integer or same symbolic expression), the result is that dimension;
- if one is a concrete integer other than 1 and the other is symbolic, the concrete integer wins: for the broadcast to be valid, the symbolic dimension can only be 1 or that same integer, and in both cases the result is the integer;
- if both are different symbolic expressions, a fresh symbolic dimension is produced that encodes the broadcast as "broadcast(<a>, <b>)", so the symbolic information is preserved.
- Throws:
std::invalid_argument – when two concrete integer dimensions are incompatible under broadcasting.
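
The pairing rules above can be sketched in C++17 as follows. This is a minimal illustration, not the library's implementation: it assumes a dimension can be modelled as std::variant<std::int64_t, std::string> (a concrete integer or a symbolic expression), and the names Dim, Shape, DimToString, BroadcastDim and BroadcastShapesSketch are hypothetical stand-ins rather than part of the OptimShape API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

// Hypothetical model of a dimension: a concrete integer or a symbolic
// expression. The real OptimShape representation may differ.
using Dim = std::variant<std::int64_t, std::string>;
using Shape = std::vector<Dim>;

// Renders a dimension as text, used to build "broadcast(<a>, <b>)".
std::string DimToString(const Dim &d) {
  if (const auto *v = std::get_if<std::int64_t>(&d)) return std::to_string(*v);
  return std::get<std::string>(d);
}

// Combines one pair of right-aligned dimensions following the rules above.
Dim BroadcastDim(const Dim &a, const Dim &b) {
  const auto *ia = std::get_if<std::int64_t>(&a);
  const auto *ib = std::get_if<std::int64_t>(&b);
  if (ia && ib) {
    // Both concrete: equal values or a 1 on either side are accepted.
    if (*ia == *ib || *ib == 1) return a;
    if (*ia == 1) return b;
    throw std::invalid_argument("incompatible dimensions " +
                                std::to_string(*ia) + " and " +
                                std::to_string(*ib));
  }
  if (ia && *ia == 1) return b;  // 1 pairs with anything
  if (ib && *ib == 1) return a;
  if (a == b) return a;          // same symbolic expression
  if (ia) return a;              // concrete non-1 integer beats a symbol
  if (ib) return b;
  // Two different symbols: record the broadcast in a fresh symbol.
  return "broadcast(" + DimToString(a) + ", " + DimToString(b) + ")";
}

// Right-aligns both shapes; missing leading dimensions count as 1.
Shape BroadcastShapesSketch(const Shape &a, const Shape &b) {
  const std::size_t rank = std::max(a.size(), b.size());
  Shape out(rank);
  for (std::size_t i = 0; i < rank; ++i) {
    // i counts axes from the trailing end of each shape.
    const Dim da = i < a.size() ? a[a.size() - 1 - i] : Dim{std::int64_t{1}};
    const Dim db = i < b.size() ? b[b.size() - 1 - i] : Dim{std::int64_t{1}};
    out[rank - 1 - i] = BroadcastDim(da, db);
  }
  return out;
}
```

For example, broadcasting {"N", 3, 1} against {4} yields {"N", 3, 4}, and BroadcastDim applied to the symbols "N" and "M" yields "broadcast(N, M)". Checking the "either is 1" rule before the equality and concrete-versus-symbolic rules keeps a unit dimension from ever overriding a symbolic one.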

void ComputeShapeBinaryBroadcast(ShapesContext &ctx, const NodeProto &node, const char *input_a, const char *input_b, const char *expected_op_type, TensorType output_dtype)
Generic shape-inference helper for binary ONNX operators that support numpy-style broadcasting. Reads the descriptors of input_a and input_b from ctx, computes the broadcast output shape, and stores a new entry under node.output(0) with the given output_dtype. The helper enforces the following preconditions:
- node.op_type() must equal expected_op_type;
- node must declare at least one output;
- both input_a and input_b must be present in ctx.
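
The control flow of ComputeShapeBinaryBroadcast can be approximated by the sketch below. It rests on several assumptions: Ctx, Node, TensorDesc and DType are hypothetical, simplified stand-ins for ShapesContext, NodeProto and TensorType; shapes hold concrete integers only, so symbolic dimensions are not handled; and precondition violations are reported by throwing std::invalid_argument, which this documentation does not specify for the real helper.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins; the real ShapesContext, NodeProto and TensorType
// have different interfaces.
enum class DType { kFloat, kBool };

struct TensorDesc {
  DType dtype;
  std::vector<std::int64_t> shape;  // concrete dimensions only, for brevity
};

struct Node {
  std::string op_type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

using Ctx = std::map<std::string, TensorDesc>;

// Concrete-only broadcast: right-aligned, missing leading dims count as 1.
std::vector<std::int64_t> BroadcastConcrete(const std::vector<std::int64_t> &a,
                                            const std::vector<std::int64_t> &b) {
  const std::size_t rank = std::max(a.size(), b.size());
  std::vector<std::int64_t> out(rank);
  for (std::size_t i = 0; i < rank; ++i) {
    const std::int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const std::int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1)
      throw std::invalid_argument("incompatible dimensions");
    out[rank - 1 - i] = (da == 1) ? db : da;
  }
  return out;
}

void ComputeBinaryBroadcastSketch(Ctx &ctx, const Node &node,
                                  const std::string &input_a,
                                  const std::string &input_b,
                                  const std::string &expected_op_type,
                                  DType output_dtype) {
  // Preconditions from the documentation; throwing is an assumption here.
  if (node.op_type != expected_op_type)
    throw std::invalid_argument("unexpected op_type " + node.op_type);
  if (node.outputs.empty())
    throw std::invalid_argument("node declares no output");
  const auto it_a = ctx.find(input_a);
  const auto it_b = ctx.find(input_b);
  if (it_a == ctx.end() || it_b == ctx.end())
    throw std::invalid_argument("missing input descriptor");
  // Store the broadcast result under the node's first output name.
  ctx[node.outputs[0]] = TensorDesc{
      output_dtype, BroadcastConcrete(it_a->second.shape, it_b->second.shape)};
}
```

With "A" of shape {2, 1, 3} and "B" of shape {4, 1} under an Add node whose first output is "C", the sketch stores {2, 4, 3} under "C". Passing the output dtype explicitly lets the same helper serve operators such as comparisons, whose result type differs from their inputs.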