include_preview_kernels.h
-
namespace ONNX_LIGHT_NAMESPACE
-
namespace onnx_backend_test
-
namespace kernel
-
class FlexAttention
- #include <include_preview_kernels.h>
Reference implementation of
ai.onnx.preview::FlexAttention (v1). Computes
Y = Softmax((Q @ K^T) * scale, axis=-1) @ V over rank-4 (batched, multi-head) FLOAT inputs. Supports Grouped Query Attention (GQA): when q_num_heads != kv_num_heads, each K/V head is shared by a contiguous group of query heads, i.e. query head h uses K/V head floor(h / (q_num_heads / kv_num_heads)); q_num_heads must be a multiple of kv_num_heads.
The optional score_mod and prob_mod modifier subgraphs of the upstream operator are not modeled by this reference kernel; it implements the unmodified baseline that backends are expected to reproduce when neither modifier is provided.
Only FLOAT tensors are supported.
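The formula and the GQA head mapping described above can be sketched in standalone C++. This is a minimal illustration, not the kernel's actual code: `KvHeadForQueryHead` and `AttendOneHead` are hypothetical helpers, the buffers are plain row-major `std::vector<float>` rather than `Tensor`, and the max-subtraction softmax is an assumed numerical-stability choice that the documentation does not specify.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper (not part of the documented API): returns the K/V head
// shared by query head `h` under Grouped Query Attention.
inline std::size_t KvHeadForQueryHead(std::size_t h, std::size_t q_num_heads,
                                      std::size_t kv_num_heads) {
  assert(kv_num_heads > 0 && q_num_heads % kv_num_heads == 0);
  assert(h < q_num_heads);
  // Unsigned integer division is the floor in floor(h / group_size).
  return h / (q_num_heads / kv_num_heads);
}

// Hypothetical single-head attention on row-major buffers:
// Q is [q_len x d], K is [kv_len x d], V is [kv_len x dv]; returns [q_len x dv].
inline std::vector<float> AttendOneHead(const std::vector<float>& Q,
                                        const std::vector<float>& K,
                                        const std::vector<float>& V,
                                        std::size_t q_len, std::size_t kv_len,
                                        std::size_t d, std::size_t dv,
                                        float scale) {
  assert(kv_len > 0);
  assert(Q.size() == q_len * d && K.size() == kv_len * d &&
         V.size() == kv_len * dv);
  std::vector<float> Y(q_len * dv, 0.0f);
  std::vector<float> scores(kv_len);
  for (std::size_t i = 0; i < q_len; ++i) {
    // scores[j] = (Q[i] . K[j]) * scale
    float max_score = -std::numeric_limits<float>::infinity();
    for (std::size_t j = 0; j < kv_len; ++j) {
      float dot = 0.0f;
      for (std::size_t k = 0; k < d; ++k) dot += Q[i * d + k] * K[j * d + k];
      scores[j] = dot * scale;
      max_score = std::max(max_score, scores[j]);
    }
    // Softmax over the last axis; subtracting the row maximum is an assumed
    // stability measure and does not change the result mathematically.
    float sum = 0.0f;
    for (std::size_t j = 0; j < kv_len; ++j) {
      scores[j] = std::exp(scores[j] - max_score);
      sum += scores[j];
    }
    // Y[i] = sum_j softmax(scores)[j] * V[j]
    for (std::size_t j = 0; j < kv_len; ++j) {
      const float p = scores[j] / sum;
      for (std::size_t c = 0; c < dv; ++c) Y[i * dv + c] += p * V[j * dv + c];
    }
  }
  return Y;
}
```

With q_num_heads = 8 and kv_num_heads = 2, query heads 0 to 3 map to K/V head 0 and query heads 4 to 7 map to K/V head 1. A full rank-4 version would loop over batch and query heads, picking the K/V slice via `KvHeadForQueryHead`.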
Public Functions
-
inline explicit FlexAttention(const KernelContext &ctx)
-
Tensor operator()(const Tensor &Q, const Tensor &K, const Tensor &V) const
Computes the attention output for the given Q, K, V tensors using the default scaling factor
1 / sqrt(head_size).
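As a small illustration of the default scaling factor, the sketch below computes 1 / sqrt(head_size). `DefaultScale` is a hypothetical helper, and it assumes head_size is a positive element count (for example, the last dimension of Q); the kernel's own derivation of head_size is not shown in this page.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical helper (not part of the documented API): default attention
// scale for a given head size, i.e. 1 / sqrt(head_size).
inline float DefaultScale(std::size_t head_size) {
  assert(head_size > 0);  // a zero head size would divide by zero
  return 1.0f / std::sqrt(static_cast<float>(head_size));
}
```

For head_size = 64 this gives 0.125, so calling the three-argument overload should behave like passing scale = 0.125f to the explicit-scale overload, under the assumption above.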
-
Tensor operator()(const Tensor &Q, const Tensor &K, const Tensor &V, float scale) const
Computes the attention output for the given Q, K, V tensors using an explicit
scale value (matching the scale attribute of the operator).
-
void operator()(const Tensor &Q, const Tensor &K, const Tensor &V, float scale, Tensor &output) const
Overload that writes the result into a caller-allocated
output tensor. output must already be a FLOAT tensor whose shape equals (batch_size, q_num_heads, q_seq_len, v_head_size) and whose data buffer has been sized to match.
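Before calling this overload, the caller has to work out the output shape and buffer size. The sketch below shows one way to do that. It assumes Q is laid out as (batch_size, q_num_heads, q_seq_len, head_size) and V as (batch_size, kv_num_heads, kv_seq_len, v_head_size); `Shape4`, `ExpectedOutputShape`, and `ElementCount` are hypothetical names, not part of the `Tensor` API.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical rank-4 shape type used only for this illustration.
using Shape4 = std::array<std::int64_t, 4>;

// Assumed layouts: Q is (batch, q_num_heads, q_seq_len, head_size) and
// V is (batch, kv_num_heads, kv_seq_len, v_head_size).
// Returns (batch_size, q_num_heads, q_seq_len, v_head_size).
inline Shape4 ExpectedOutputShape(const Shape4& q_shape, const Shape4& v_shape) {
  assert(q_shape[0] == v_shape[0]);  // batch sizes must agree
  return Shape4{q_shape[0], q_shape[1], q_shape[2], v_shape[3]};
}

// Number of float elements the output data buffer must hold.
inline std::size_t ElementCount(const Shape4& shape) {
  std::size_t n = 1;
  for (std::int64_t dim : shape) {
    assert(dim >= 0);
    n *= static_cast<std::size_t>(dim);
  }
  return n;
}
```

For Q of shape (2, 8, 5, 64) and V of shape (2, 2, 7, 32), the expected output shape is (2, 8, 5, 32), which requires a buffer of 2560 floats.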
Public Static Functions
-
static inline constexpr bool CanRunInPlace() noexcept
FlexAttention computes a fresh output buffer from independent reads of Q, K, V and never aliases an input buffer.
Private Members
-
KernelContext ctx_