include_preview_kernels.h

namespace ONNX_LIGHT_NAMESPACE
namespace onnx_backend_test
namespace kernel
class FlexAttention
#include <include_preview_kernels.h>

Reference implementation of ai.onnx.preview::FlexAttention (v1).

Computes Y = Softmax((Q @ K^T) * scale, axis=-1) @ V over rank-4 (batched, multi-head) FLOAT inputs.

Supports Grouped Query Attention (GQA): when q_num_heads != kv_num_heads, each K/V head is shared by a contiguous group of q_num_heads / kv_num_heads query heads, so query head h uses K/V head floor(h / (q_num_heads / kv_num_heads)). For example, with q_num_heads = 8 and kv_num_heads = 2, query heads 0 to 3 use K/V head 0 and query heads 4 to 7 use K/V head 1. q_num_heads must be a multiple of kv_num_heads.
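
The computation described above can be sketched in plain C++. This is a minimal, non-authoritative sketch, not the kernel's actual code: it works on flat std::vector<float> buffers instead of the library's Tensor type, assumes a row-major (batch, heads, seq_len, head_size) layout for Q, K, V, and the output, and the function name flex_attention_ref is hypothetical. The softmax subtracts the row maximum before exponentiating, a standard numerical-stability step that does not change the result.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference sketch (not the library's Tensor API).
// Assumed layout, row-major: Q (batch, q_heads, q_len, head_size),
// K (batch, kv_heads, kv_len, head_size), V (batch, kv_heads, kv_len, v_head_size).
// Returns Y with shape (batch, q_heads, q_len, v_head_size).
std::vector<float> flex_attention_ref(const std::vector<float>& q,
                                      const std::vector<float>& k,
                                      const std::vector<float>& v,
                                      std::size_t batch, std::size_t q_heads,
                                      std::size_t kv_heads, std::size_t q_len,
                                      std::size_t kv_len, std::size_t head_size,
                                      std::size_t v_head_size, float scale) {
  assert(kv_heads > 0 && q_heads % kv_heads == 0);  // GQA divisibility
  assert(kv_len > 0);                               // softmax needs at least one key
  assert(q.size() == batch * q_heads * q_len * head_size);
  assert(k.size() == batch * kv_heads * kv_len * head_size);
  assert(v.size() == batch * kv_heads * kv_len * v_head_size);

  const std::size_t group = q_heads / kv_heads;  // query heads per K/V head
  std::vector<float> out(batch * q_heads * q_len * v_head_size, 0.0f);
  std::vector<float> scores(kv_len);  // reused for every query row

  for (std::size_t b = 0; b < batch; ++b) {
    for (std::size_t h = 0; h < q_heads; ++h) {
      const std::size_t kvh = h / group;  // contiguous groups share a K/V head
      const float* qh = q.data() + (b * q_heads + h) * q_len * head_size;
      const float* kh = k.data() + (b * kv_heads + kvh) * kv_len * head_size;
      const float* vh = v.data() + (b * kv_heads + kvh) * kv_len * v_head_size;
      float* oh = out.data() + (b * q_heads + h) * q_len * v_head_size;

      for (std::size_t i = 0; i < q_len; ++i) {
        // Scaled dot products of query row i against every key row.
        for (std::size_t j = 0; j < kv_len; ++j) {
          float dot = 0.0f;
          for (std::size_t d = 0; d < head_size; ++d)
            dot += qh[i * head_size + d] * kh[j * head_size + d];
          scores[j] = dot * scale;
        }
        // Softmax over the last axis, shifted by the row maximum.
        const float mx = *std::max_element(scores.begin(), scores.end());
        float sum = 0.0f;
        for (float& s : scores) {
          s = std::exp(s - mx);
          sum += s;
        }
        // Probability-weighted sum of the value rows.
        for (std::size_t j = 0; j < kv_len; ++j) {
          const float p = scores[j] / sum;
          for (std::size_t d = 0; d < v_head_size; ++d)
            oh[i * v_head_size + d] += p * vh[j * v_head_size + d];
        }
      }
    }
  }
  return out;
}
```

The sketch favors readability over speed; it allocates one scores buffer and reuses it for every query row rather than materializing the full (q_len, kv_len) score matrix.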

The optional score_mod and prob_mod modifier subgraphs of the upstream operator are not modeled by this reference kernel. It implements the unmodified baseline that backends are expected to reproduce when neither modifier is provided.

Only FLOAT tensors are supported.

Public Functions

inline explicit FlexAttention(const KernelContext &ctx)

Tensor operator()(const Tensor &Q, const Tensor &K, const Tensor &V) const

Computes the attention output for the given Q, K, V tensors using the default scaling factor 1 / sqrt(head_size), where head_size is the last dimension of Q (and of K).
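
The default scale reduces to a single expression. The helper below is a hypothetical illustration and is not declared in include_preview_kernels.h; it assumes head_size is the shared last dimension of Q and K. Under that assumption, the three-argument overload is expected to match the four-argument overload called with default_scale(head_size).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical helper: the default FlexAttention scale, 1 / sqrt(head_size).
float default_scale(std::size_t head_size) {
  assert(head_size > 0);  // a zero head size has no meaningful scale
  return 1.0f / std::sqrt(static_cast<float>(head_size));
}
```

For head_size = 64 this yields 1 / 8 = 0.125f, which is exactly representable as a float.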

Tensor operator()(const Tensor &Q, const Tensor &K, const Tensor &V, float scale) const

Computes the attention output for the given Q, K, V tensors using an explicit scale value (matching the scale attribute of the operator).

void operator()(const Tensor &Q, const Tensor &K, const Tensor &V, float scale, Tensor &output) const

Overload that writes into a caller-allocated output tensor instead of returning a new one. output must already be a FLOAT tensor whose shape equals (batch_size, q_num_heads, q_seq_len, v_head_size) and whose data buffer is sized to match.
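
Before calling this overload, a caller has to derive the output shape from the inputs. The sketch below assumes each shape is given as (batch, heads, seq_len, head_size) in a std::array<std::int64_t, 4>; the Shape4 alias and both functions are hypothetical. The assert checks are the preconditions implied by the description above (matching batch sizes, matching K/V head counts and sequence lengths, matching Q/K head sizes, GQA divisibility), not necessarily the checks the kernel itself performs.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical shape type: (batch, heads, seq_len, head_size).
using Shape4 = std::array<std::int64_t, 4>;

// Shape the caller-allocated output must have:
// (batch_size, q_num_heads, q_seq_len, v_head_size).
Shape4 expected_output_shape(const Shape4& q, const Shape4& k, const Shape4& v) {
  assert(q[0] == k[0] && k[0] == v[0]);  // same batch size
  assert(k[1] == v[1] && k[1] > 0);      // K and V share kv_num_heads
  assert(q[1] % k[1] == 0);              // GQA: q heads are a multiple of kv heads
  assert(k[2] == v[2]);                  // K and V share kv_seq_len
  assert(q[3] == k[3]);                  // Q and K share head_size
  return {q[0], q[1], q[2], v[3]};
}

// Number of floats the output data buffer must hold.
std::int64_t element_count(const Shape4& s) {
  return s[0] * s[1] * s[2] * s[3];
}
```

For example, Q of shape (2, 8, 5, 64), K of shape (2, 2, 7, 64), and V of shape (2, 2, 7, 32) give an output shape of (2, 8, 5, 32), i.e. a buffer of 2560 floats.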

Public Static Functions

static inline constexpr bool CanRunInPlace() noexcept

Reports whether the kernel may write its output into one of its input buffers. FlexAttention computes a fresh output buffer from independent reads of Q, K, and V and never aliases an input buffer.

Private Members

KernelContext ctx_