FlexAttention
Domain: ai.onnx.preview
Since version: 1
Computes scaled dot-product attention over rank-4 (batched, multi-head) inputs, with optional user-provided customization subgraphs at two stages:
score_mod: Modify the attention score tensor after Q·K^T
prob_mod: Modify the probability tensor after Softmax
This operator mirrors the capabilities of PyTorch’s flex_attention: https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html
Input Shapes (MUST be rank-4 tensors):
Q: (batch_size, q_num_heads, q_sequence_length, head_size)
K: (batch_size, kv_num_heads, kv_sequence_length, head_size)
V: (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
Output Shape:
Y: (batch_size, q_num_heads, q_sequence_length, v_head_size)
FlexAttention Computation:
Scores = (Q @ K^T) * scale
Scores = score_mod(Scores) # if 'score_mod' is provided
Probs = Softmax(Scores, axis=-1)
Probs = prob_mod(Probs) # if 'prob_mod' is provided
Y = Probs @ V
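The steps above can be sketched in NumPy as follows. This is an illustrative sketch under stated assumptions, not the operator's reference implementation: the function name `flex_attention_ref` is hypothetical, the modifiers are modeled as plain Python callables rather than ONNX subgraphs, Q and K/V are assumed to have the same number of heads, and the default `scale = 1/sqrt(head_size)` is an assumption that happens to be consistent with the test_cc_flex_attention_basic output listed under Examples.

```python
import math

import numpy as np


def flex_attention_ref(Q, K, V, scale=None, score_mod=None, prob_mod=None):
    """Hypothetical NumPy sketch of the FlexAttention steps (no GQA)."""
    if scale is None:
        # Assumed default; consistent with the example outputs in this document.
        scale = 1.0 / math.sqrt(Q.shape[-1])
    # Scores = (Q @ K^T) * scale, transposing only the last two axes of K.
    scores = (Q @ np.swapaxes(K, -1, -2)) * scale
    if score_mod is not None:
        scores = score_mod(scores)
    # Numerically stable softmax over the kv_sequence_length axis.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    if prob_mod is not None:
        probs = prob_mod(probs)
    # Y = Probs @ V
    return probs @ V


# Inputs copied from test_cc_flex_attention_basic.
Q = np.array([[[[1.0, 0.0], [0.0, 1.0]],
               [[0.5, 0.5], [1.0, -1.0]]]], dtype=np.float32)
K = np.array([[[[1.0, 0.0], [0.0, 1.0]],
               [[1.0, 1.0], [-1.0, 1.0]]]], dtype=np.float32)
V = np.array([[[[1.0, 2.0], [3.0, 4.0]],
               [[-1.0, 0.0], [0.0, 1.0]]]], dtype=np.float32)

Y = flex_attention_ref(Q, K, V)
print(Y.shape)  # (1, 2, 2, 2)
```

Passing `score_mod` or `prob_mod` as a function that maps an array to an array of the same shape and dtype mimics the single-input, single-output contract of the modifier subgraphs.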
Grouped Query Attention (GQA):
When q_num_heads != kv_num_heads, each K/V head is shared by a contiguous
group of query heads in head-index order. Let
group_size = q_num_heads / kv_num_heads; then query head h uses K/V head
floor(h / group_size). q_num_heads must be a multiple of
kv_num_heads.
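The head mapping can be illustrated with a short NumPy sketch. The helper name `expand_kv_heads` is hypothetical, and repeating each K/V head `group_size` times along the head axis is only one possible way (assumed here, not mandated by the text) to realize the rule that query head h uses K/V head floor(h / group_size).

```python
import numpy as np


def expand_kv_heads(x, q_num_heads):
    """Repeat each K/V head so that the head axis matches q_num_heads."""
    kv_num_heads = x.shape[1]
    if q_num_heads % kv_num_heads != 0:
        raise ValueError("q_num_heads must be a multiple of kv_num_heads")
    group_size = q_num_heads // kv_num_heads
    # np.repeat yields heads [0, 0, ..., 1, 1, ...], so expanded head h
    # is a copy of original head h // group_size.
    return np.repeat(x, group_size, axis=1)


q_num_heads, kv_num_heads = 4, 2
group_size = q_num_heads // kv_num_heads
head_map = [h // group_size for h in range(q_num_heads)]
print(head_map)  # [0, 0, 1, 1]

# Dummy K tensor: (batch_size, kv_num_heads, kv_sequence_length, head_size).
K = np.arange(1 * 2 * 3 * 2, dtype=np.float32).reshape(1, 2, 3, 2)
K_expanded = expand_kv_heads(K, q_num_heads)
print(K_expanded.shape)  # (1, 4, 3, 2)
```

After this expansion, K and V have q_num_heads heads and the computation proceeds as in the equal-head-count case. An implementation may avoid materializing the copies; the sketch favors clarity over memory use.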
Modifier Subgraphs (score_mod, prob_mod): Each modifier subgraph takes exactly one rank-4 tensor input and must produce exactly one rank-4 tensor output of the same shape and element type.
score_mod input/output shape: (batch_size, q_num_heads, q_sequence_length, kv_sequence_length)
prob_mod input/output shape: (batch_size, q_num_heads, q_sequence_length, kv_sequence_length)
The element type is determined by softmax_precision, which defaults to float32 for non-double inputs and to double for double inputs.
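A minimal sketch of that default rule, assuming NumPy dtypes stand in for ONNX element types. The helper `default_softmax_dtype` is hypothetical, and performing the cast right after Q·K^T is an assumption of this sketch; the text above only fixes the element type that the modifiers see.

```python
import numpy as np


def default_softmax_dtype(input_dtype):
    """Assumed default: double stays double, every other input uses float32."""
    return np.float64 if np.dtype(input_dtype) == np.float64 else np.float32


# float16 inputs: (batch_size, num_heads, sequence_length, head_size).
Q = np.ones((1, 1, 2, 4), dtype=np.float16)
K = np.ones((1, 1, 3, 4), dtype=np.float16)

raw_scores = Q @ np.swapaxes(K, -1, -2)  # still float16
# Assumption: upcast here, so score_mod and Softmax operate in float32.
scores = raw_scores.astype(default_softmax_dtype(Q.dtype))
print(scores.dtype, scores.shape)  # float32 (1, 1, 2, 3)
```

A modifier written against this contract should therefore return the dtype it receives rather than the dtype of Q, K, and V.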
Masking can be expressed in score_mod by writing masked positions as -inf (or a large negative value appropriate for the target precision).
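As an illustration of masking, the sketch below models a causal score_mod as a NumPy function. The name `causal_score_mod` is hypothetical, the modifier is a plain Python callable rather than an ONNX subgraph, and aligning the mask at the top-left corner (query i may attend to keys 0..i) is an assumption made for this example.

```python
import numpy as np


def causal_score_mod(scores):
    """Hypothetical modifier: mask key positions that lie after the query."""
    q_len, kv_len = scores.shape[-2:]
    # allowed[i, j] is True when query position i may attend to key j.
    allowed = np.tril(np.ones((q_len, kv_len), dtype=bool))
    # Broadcasting applies the same (q_len, kv_len) mask to every batch and head.
    return np.where(allowed, scores, -np.inf).astype(scores.dtype)


# Uniform scores: (batch_size, q_num_heads, q_sequence_length, kv_sequence_length).
scores = np.zeros((1, 2, 3, 3), dtype=np.float32)
masked = causal_score_mod(scores)

# Softmax over the last axis: masked positions receive probability 0.
exp = np.exp(masked - masked.max(axis=-1, keepdims=True))
probs = exp / exp.sum(axis=-1, keepdims=True)
print(probs[0, 0])  # rows: [1, 0, 0], [0.5, 0.5, 0], [1/3, 1/3, 1/3]
```

The output keeps the input shape and dtype, as the modifier contract requires. A row that is masked entirely would produce NaN under this softmax, which is one reason a large finite negative value is sometimes preferred over -inf.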
Inputs
Q (T1): Query tensor with shape (batch_size, q_num_heads, q_seq_len, head_size).
K (T1): Key tensor with shape (batch_size, kv_num_heads, kv_seq_len, head_size).
V (T1): Value tensor with shape (batch_size, kv_num_heads, kv_seq_len, v_head_size).
Outputs
Y (T1): Output tensor with shape (batch_size, q_num_heads, q_seq_len, v_head_size).
Type Constraints
T1: Constrain Q, K, V to float tensors. Allowed types: tensor(bfloat16), tensor(double), tensor(float), tensor(float16).
Examples
test_cc_flex_attention_basic
Inputs:
Q: shape=(1, 2, 2, 2), dtype=float32
[[[[ 1. , 0. ],
[ 0. , 1. ]],
[[ 0.5, 0.5],
[ 1. , -1. ]]]]
K: shape=(1, 2, 2, 2), dtype=float32
[[[[ 1., 0.],
[ 0., 1.]],
[[ 1., 1.],
[-1., 1.]]]]
V: shape=(1, 2, 2, 2), dtype=float32
[[[[ 1., 2.],
[ 3., 4.]],
[[-1., 0.],
[ 0., 1.]]]]
Outputs:
Y: shape=(1, 2, 2, 2), dtype=float32
[[[[ 1.6604769 , 2.660477 ],
[ 2.339523 , 3.339523 ]],
[[-0.66976154, 0.33023846],
[-0.80442965, 0.19557032]]]]
test_cc_flex_attention_gqa
Inputs:
Q: shape=(1, 4, 2, 2), dtype=float32
[[[[ 0.1 , 0.2 ],
[ 0.3 , 0.4 ]],
[[-0.1 , 0.05],
[ 0.2 , -0.3 ]],
[[ 0.5 , 0.5 ],
[ 0. , 1. ]],
[[ 1. , 0. ],
[ 0.5 , -0.5 ]]]]
K: shape=(1, 2, 3, 2), dtype=float32
[[[[ 1. , 0. ],
[ 0.5 , 0.5 ],
[ 0. , 1. ]],
[[-1. , 1. ],
[ 1. , 1. ],
[ 0.25, -0.5 ]]]]
V: shape=(1, 2, 3, 2), dtype=float32
[[[[ 1. , 0. ],
[ 0. , 1. ],
[-1. , 1. ]],
[[ 2. , -2. ],
[ 0.5 , 0.25],
[-0.5 , 0. ]]]]
Outputs:
Y: shape=(1, 4, 2, 2), dtype=float32
[[[[-0.02356532, 0.6783799 ],
[-0.02356531, 0.6783799 ]],
[[-0.03533878, 0.6841799 ],
[ 0.11724145, 0.6063233 ]],
[[ 0.6482418 , -0.37858847],
[ 0.9917567 , -0.74587834]],
[[ 0.37784207, -0.12898168],
[ 0.29831943, -0.26321504]]]]