shape_preview.h
Shape-inference functions for ONNX operators in the ai.onnx.preview domain.
-
namespace ONNX_LIGHT_NAMESPACE
-
namespace onnx_optim
-
namespace shapes
-
namespace preview
Functions
-
void ComputeShapeFlexAttention(ShapesContext &ctx, const NodeProto &node, const char *q, const char *k, const char *v)
Computes the output OptimTensor of a FlexAttention node and stores it in ctx. FlexAttention expects three rank-4 inputs: Q with shape (batch_size, q_num_heads, q_seq_len, head_size); K with shape (batch_size, kv_num_heads, kv_seq_len, head_size); V with shape (batch_size, kv_num_heads, kv_seq_len, v_head_size).
The output Y has the same element type as Q and shape (batch_size, q_num_heads, q_seq_len, v_head_size). The function validates that Q, K, and V share the same dtype, that K and V share the same head count and sequence length, that Q and K share the same embedding dimension (head_size), and that q_num_heads is a multiple of kv_num_heads (Grouped Query Attention).
Symbolic dimensions are propagated unchanged: any constraint that cannot be verified because one side is symbolic is skipped, as in the upstream FlexAttentionShapeInference in onnx_lib.
- Parameters:
ctx – In/out context. Must already contain entries for q, k, and v; on return it also contains an entry for node.output(0).
node – The FlexAttention NodeProto whose output should be described. node.op_type() must be "FlexAttention", and node must declare at least one output.
q – Name of the query (Q) input to read from ctx.
k – Name of the key (K) input to read from ctx.
v – Name of the value (V) input to read from ctx.
- Throws:
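The shape rule described above can be sketched in a few lines. This is a minimal, self-contained illustration under stated assumptions, not the library's implementation: Dim, TensorInfo, ExpectEqual, and FlexAttentionOutput are hypothetical names, a symbolic dimension is modeled as a std::string, and failures are reported with std::invalid_argument only because the exception type is not stated here.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

// Hypothetical: a dimension is either a concrete size or a symbolic name.
using Dim = std::variant<int64_t, std::string>;

// Hypothetical stand-in for OptimTensor: element type code plus shape.
struct TensorInfo {
  int32_t elem_type;
  std::vector<Dim> shape;
};

// Fails only when both dims are concrete and differ; a symbolic side skips the check.
inline void ExpectEqual(const Dim& a, const Dim& b, const char* what) {
  if (std::holds_alternative<int64_t>(a) && std::holds_alternative<int64_t>(b) &&
      std::get<int64_t>(a) != std::get<int64_t>(b)) {
    throw std::invalid_argument(what);
  }
}

// Sketch of the FlexAttention output-shape rule (assumed error type: std::invalid_argument).
inline TensorInfo FlexAttentionOutput(const TensorInfo& q, const TensorInfo& k,
                                      const TensorInfo& v) {
  if (q.shape.size() != 4 || k.shape.size() != 4 || v.shape.size() != 4)
    throw std::invalid_argument("Q, K and V must be rank 4");
  if (q.elem_type != k.elem_type || q.elem_type != v.elem_type)
    throw std::invalid_argument("Q, K and V must share a dtype");
  ExpectEqual(k.shape[1], v.shape[1], "K and V head counts differ");
  ExpectEqual(k.shape[2], v.shape[2], "K and V sequence lengths differ");
  ExpectEqual(q.shape[3], k.shape[3], "Q and K head sizes differ");
  // Grouped Query Attention: q_num_heads % kv_num_heads == 0, skipped if either is symbolic.
  if (std::holds_alternative<int64_t>(q.shape[1]) &&
      std::holds_alternative<int64_t>(k.shape[1])) {
    const int64_t q_heads = std::get<int64_t>(q.shape[1]);
    const int64_t kv_heads = std::get<int64_t>(k.shape[1]);
    // The kv_heads <= 0 guard is an added assumption that avoids division by zero.
    if (kv_heads <= 0 || q_heads % kv_heads != 0)
      throw std::invalid_argument("q_num_heads must be a multiple of kv_num_heads");
  }
  // Y: (batch_size, q_num_heads, q_seq_len, v_head_size) with Q's element type.
  return {q.elem_type, {q.shape[0], q.shape[1], q.shape[2], v.shape[3]}};
}
```

For example, Q of shape (B, 8, 128, 64) with K of shape (B, 2, S, 64) and V of shape (B, 2, S, 32) yields Y of shape (B, 8, 128, 32); changing kv_num_heads to 3 in both K and V fails the Grouped Query Attention check, because 8 is not a multiple of 3.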
Variables
-
constexpr const char *kOnnxPreviewDomain = "ai.onnx.preview"
The ai.onnx.preview operator domain string.
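One plausible use of kOnnxPreviewDomain is as a guard before dispatching a node to ComputeShapeFlexAttention. The sketch below assumes only that the node type exposes domain() and op_type() accessors returning strings, as onnx::NodeProto does; IsPreviewFlexAttention and FakeNode are hypothetical names, and the constant is redeclared at global scope so the example compiles on its own.

```cpp
#include <cassert>
#include <string>

// Redeclared for a standalone build; the header defines it in namespace preview.
constexpr const char* kOnnxPreviewDomain = "ai.onnx.preview";

// Hypothetical dispatch check: Node is any type with domain() and op_type() accessors.
template <class Node>
bool IsPreviewFlexAttention(const Node& node) {
  return node.domain() == kOnnxPreviewDomain && node.op_type() == "FlexAttention";
}

// Hypothetical stand-in for NodeProto, used only to exercise the check.
struct FakeNode {
  std::string d;
  std::string t;
  const std::string& domain() const { return d; }
  const std::string& op_type() const { return t; }
};
```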