FlexAttention

  • Domain: ai.onnx.preview

  • Since version: 1

Computes scaled dot-product attention over rank-4 (batched, multi-head) inputs, with optional user-provided customization subgraphs at two stages:

  1. score_mod: Modify the scaled attention scores, (Q·K^T) * scale, before Softmax

  2. prob_mod: Modify the probability tensor after Softmax, before it is multiplied by V

This operator mirrors the capabilities of PyTorch’s flex_attention: https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html

Input Shapes (MUST be rank-4 tensors):

  • Q: (batch_size, q_num_heads, q_sequence_length, head_size)

  • K: (batch_size, kv_num_heads, kv_sequence_length, head_size)

  • V: (batch_size, kv_num_heads, kv_sequence_length, v_head_size)

Output Shape:

  • Y: (batch_size, q_num_heads, q_sequence_length, v_head_size)
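The shape rules above can be checked mechanically. The following is a minimal Python sketch; `flex_attention_output_shape` is a hypothetical helper, not part of any ONNX API, and it also enforces the GQA divisibility requirement stated in the Grouped Query Attention paragraph.

```python
from typing import Sequence, Tuple


def flex_attention_output_shape(q_shape: Sequence[int],
                                k_shape: Sequence[int],
                                v_shape: Sequence[int]) -> Tuple[int, int, int, int]:
    """Hypothetical helper: validate Q/K/V shapes and return the shape of Y."""
    for name, shape in (("Q", q_shape), ("K", k_shape), ("V", v_shape)):
        if len(shape) != 4:
            raise ValueError(f"{name} must be rank-4, got rank {len(shape)}")
    batch, q_heads, q_len, head_size = q_shape
    k_batch, kv_heads, kv_len, k_head_size = k_shape
    v_batch, v_heads, v_len, v_head_size = v_shape
    if not (batch == k_batch == v_batch):
        raise ValueError("batch_size must match across Q, K and V")
    if (kv_heads, kv_len) != (v_heads, v_len):
        raise ValueError("K and V must agree on kv_num_heads and kv_sequence_length")
    if head_size != k_head_size:
        raise ValueError("Q and K must share head_size")
    if kv_heads <= 0 or q_heads % kv_heads != 0:
        raise ValueError("kv_num_heads must be positive and divide q_num_heads")
    # Y keeps Q's batch, head count and sequence length, and takes V's head size.
    return (batch, q_heads, q_len, v_head_size)


# Shapes from the GQA example: 4 query heads sharing 2 K/V heads.
print(flex_attention_output_shape((1, 4, 2, 2), (1, 2, 3, 2), (1, 2, 3, 2)))  # (1, 4, 2, 2)
```

Note that v_head_size may differ from head_size: only Q and K must agree on the feature dimension, because V is contracted over the kv positions rather than over features.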

FlexAttention Computation:

Scores = (Q @ K^T) * scale
Scores = score_mod(Scores)             # if 'score_mod' is provided
Probs = Softmax(Scores, axis=-1)
Probs = prob_mod(Probs)                # if 'prob_mod' is provided
Y = Probs @ V
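The computation above can be sketched in plain Python on nested lists indexed [batch][head][position][feature]. This is a minimal reference sketch under stated assumptions, not a normative implementation: `flex_attention_reference` is a hypothetical name; `score_mod` and `prob_mod` are modeled as Python callables rather than ONNX subgraphs; when `scale` is omitted it falls back to 1/sqrt(head_size), an assumption that reproduces the outputs listed under Examples; and all arithmetic uses Python floats (double precision), ignoring softmax_precision.

```python
import math
from typing import Callable, List, Optional

Tensor4 = List[List[List[List[float]]]]
Modifier = Callable[[Tensor4], Tensor4]


def _softmax(row: List[float]) -> List[float]:
    # Subtract the row maximum for numerical stability; exp(-inf) evaluates to 0.0.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def flex_attention_reference(Q: Tensor4, K: Tensor4, V: Tensor4,
                             scale: Optional[float] = None,
                             score_mod: Optional[Modifier] = None,
                             prob_mod: Optional[Modifier] = None) -> Tensor4:
    q_heads, kv_heads = len(Q[0]), len(K[0])
    group_size = q_heads // kv_heads  # GQA: query head h reads K/V head h // group_size
    if scale is None:
        scale = 1.0 / math.sqrt(len(Q[0][0][0]))  # assumed default: 1/sqrt(head_size)

    # Scores = (Q @ K^T) * scale -> (batch, q_num_heads, q_seq_len, kv_seq_len)
    scores = [[[[scale * sum(a * b for a, b in zip(q_row, k_row))
                 for k_row in K[b][h // group_size]]
                for q_row in Q[b][h]]
               for h in range(q_heads)]
              for b in range(len(Q))]
    if score_mod is not None:
        scores = score_mod(scores)

    # Probs = Softmax(Scores, axis=-1)
    probs = [[[_softmax(row) for row in head] for head in batch] for batch in scores]
    if prob_mod is not None:
        probs = prob_mod(probs)

    # Y = Probs @ V -> (batch, q_num_heads, q_seq_len, v_head_size)
    Y: Tensor4 = []
    for b, batch in enumerate(probs):
        Y.append([])
        for h, head in enumerate(batch):
            v_head = V[b][h // group_size]
            Y[b].append([[sum(p * v_row[j] for p, v_row in zip(p_row, v_head))
                          for j in range(len(v_head[0]))]
                         for p_row in head])
    return Y


# Inputs of test_cc_flex_attention_basic
Q = [[[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.5], [1.0, -1.0]]]]
K = [[[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [-1.0, 1.0]]]]
V = [[[[1.0, 2.0], [3.0, 4.0]], [[-1.0, 0.0], [0.0, 1.0]]]]
print(flex_attention_reference(Q, K, V)[0][0][0])  # approx. [1.6604769, 2.660477]
```

Passing identity callables for both modifiers leaves the result unchanged, which is a quick sanity check that the hooks sit at the right stages.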

Grouped Query Attention (GQA): When q_num_heads != kv_num_heads, each K/V head is shared by a contiguous group of query heads in head-index order. Let group_size = q_num_heads / kv_num_heads; then query head h uses K/V head floor(h / group_size). q_num_heads must be a multiple of kv_num_heads.
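The head mapping can be written as a small function. This is an illustrative sketch; `kv_head_for_query_head` is a hypothetical name, not an ONNX API.

```python
def kv_head_for_query_head(h: int, q_num_heads: int, kv_num_heads: int) -> int:
    """Return the K/V head index used by query head h under GQA."""
    if kv_num_heads <= 0 or q_num_heads % kv_num_heads != 0:
        raise ValueError("q_num_heads must be a multiple of kv_num_heads")
    if not 0 <= h < q_num_heads:
        raise IndexError("query head index out of range")
    group_size = q_num_heads // kv_num_heads
    # Contiguous groups: heads [0, group_size) -> 0, [group_size, 2*group_size) -> 1, ...
    return h // group_size


# The GQA example below uses 4 query heads sharing 2 K/V heads.
print([kv_head_for_query_head(h, 4, 2) for h in range(4)])  # [0, 0, 1, 1]
```

Two special cases fall out of the same formula: with q_num_heads == kv_num_heads, group_size is 1 and every query head uses its own K/V head (standard multi-head attention); with kv_num_heads == 1, every query head shares the single K/V head.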

Modifier Subgraphs (score_mod, prob_mod): Each modifier subgraph takes exactly one rank-4 tensor input and must produce exactly one rank-4 tensor output of the same shape and element type.

  • score_mod input/output shape: (batch_size, q_num_heads, q_sequence_length, kv_sequence_length)

  • prob_mod input/output shape: (batch_size, q_num_heads, q_sequence_length, kv_sequence_length)

The element type is determined by softmax_precision. If softmax_precision is not set, it is float32 when the inputs are float16, bfloat16, or float32, and double when the inputs are double.
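The modifier contract (one input, one output, same shape) and the element-type rule can be sketched as follows. The helpers are hypothetical, and element types are represented as plain strings purely for illustration; the actual encoding of the softmax_precision attribute is not described in this section.

```python
from typing import Callable, List, Optional, Tuple

Tensor4 = List[List[List[List[float]]]]


def modifier_element_type(input_type: str, softmax_precision: Optional[str] = None) -> str:
    """Hypothetical helper: element type seen by score_mod and prob_mod."""
    if softmax_precision is not None:
        return softmax_precision
    # Default: double stays double; float16, bfloat16 and float32 are computed in float32.
    return "double" if input_type == "double" else "float32"


def _shape4(t: Tensor4) -> Tuple[int, int, int, int]:
    # Reads sizes from the first element along each axis (assumes a non-ragged tensor).
    return (len(t), len(t[0]), len(t[0][0]), len(t[0][0][0]))


def check_modifier(mod: Callable[[Tensor4], Tensor4], x: Tensor4) -> Tensor4:
    """Apply a modifier and enforce that the output shape equals the input shape."""
    y = mod(x)
    if _shape4(y) != _shape4(x):
        raise ValueError(f"modifier changed shape {_shape4(x)} -> {_shape4(y)}")
    return y


print(modifier_element_type("float16"))  # float32
print(modifier_element_type("double"))   # double
```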

Masking can be expressed in score_mod by setting masked positions to -inf (or to a large negative value appropriate for the target precision), so that they receive zero (or negligible) probability after Softmax.
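As an illustration, the effect a causal-mask score_mod should have can be sketched in Python. This sketch uses loop indices to decide which positions to mask and assumes query position i is aligned with kv position i (top-left alignment). A real score_mod subgraph receives only the score tensor, so it would have to derive the same mask from that tensor's shape, for example with ONNX ops such as Shape, Range, and Where, or Trilu; `causal_score_mod` is a hypothetical name.

```python
import math
from typing import List

Tensor4 = List[List[List[List[float]]]]


def causal_score_mod(scores: Tensor4) -> Tensor4:
    """Sketch of a causal mask: query position i may not attend to kv position j > i."""
    return [[[[s if j <= i else -math.inf for j, s in enumerate(row)]
              for i, row in enumerate(head)]
             for head in batch]
            for batch in scores]


scores = [[[[0.5, 1.0, 2.0],
            [0.1, 0.2, 0.3]]]]  # shape (1, 1, 2, 3)
masked = causal_score_mod(scores)
print(masked[0][0])  # [[0.5, -inf, -inf], [0.1, 0.2, -inf]]

# After Softmax, masked entries get probability exactly 0 because exp(-inf) == 0.0.
row = masked[0][0][0]
exps = [math.exp(x - max(row)) for x in row]
print([e / sum(exps) for e in exps])  # [1.0, 0.0, 0.0]
```

One caveat of using -inf: if every position in a row is masked, the max-subtracted softmax in this sketch computes -inf minus -inf and returns NaN, which is one reason a large finite negative value is sometimes preferred.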

Inputs

  • Q (T1): Query tensor with shape (batch_size, q_num_heads, q_sequence_length, head_size).

  • K (T1): Key tensor with shape (batch_size, kv_num_heads, kv_sequence_length, head_size).

  • V (T1): Value tensor with shape (batch_size, kv_num_heads, kv_sequence_length, v_head_size).

Outputs

  • Y (T1): Output tensor with shape (batch_size, q_num_heads, q_sequence_length, v_head_size).

Type Constraints

  • T1: Constrain Q, K, V, and Y to floating-point tensors. Allowed types: tensor(bfloat16), tensor(double), tensor(float), tensor(float16).

Examples

test_cc_flex_attention_basic

Inputs:
  Q: shape=(1, 2, 2, 2), dtype=float32
    [[[[ 1. ,  0. ],
       [ 0. ,  1. ]],

      [[ 0.5,  0.5],
       [ 1. , -1. ]]]]
  K: shape=(1, 2, 2, 2), dtype=float32
    [[[[ 1.,  0.],
       [ 0.,  1.]],

      [[ 1.,  1.],
       [-1.,  1.]]]]
  V: shape=(1, 2, 2, 2), dtype=float32
    [[[[ 1.,  2.],
       [ 3.,  4.]],

      [[-1.,  0.],
       [ 0.,  1.]]]]

Outputs:
  Y: shape=(1, 2, 2, 2), dtype=float32
    [[[[ 1.6604769 ,  2.660477  ],
       [ 2.339523  ,  3.339523  ]],

      [[-0.66976154,  0.33023846],
       [-0.80442965,  0.19557032]]]]
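To see where the first row of Y comes from, here is the arithmetic for head 0, query position 0, written out in Python. The scale 1/sqrt(head_size) = 1/sqrt(2) is an assumption about how this example was produced; it reproduces the listed values.

```python
import math

# Head 0, query row 0 of test_cc_flex_attention_basic
q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
scale = 1.0 / math.sqrt(len(q))  # assumed 1/sqrt(head_size)

scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]  # [0.7071..., 0.0]
exps = [math.exp(s - max(scores)) for s in scores]
probs = [e / sum(exps) for e in exps]                                # about [0.6698, 0.3302]
y = [sum(p * v[j] for p, v in zip(probs, values)) for j in range(len(values[0]))]
print([round(x, 6) for x in y])  # [1.660477, 2.660477]
```

Because V's two rows differ by the same amount (+2) in each column, both output columns shift by the same weighted amount, which is why Y[0][0][0] is [1.66..., 2.66...].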

test_cc_flex_attention_gqa

Inputs:
  Q: shape=(1, 4, 2, 2), dtype=float32
    [[[[ 0.1 ,  0.2 ],
       [ 0.3 ,  0.4 ]],

      [[-0.1 ,  0.05],
       [ 0.2 , -0.3 ]],

      [[ 0.5 ,  0.5 ],
       [ 0.  ,  1.  ]],

      [[ 1.  ,  0.  ],
       [ 0.5 , -0.5 ]]]]
  K: shape=(1, 2, 3, 2), dtype=float32
    [[[[ 1.  ,  0.  ],
       [ 0.5 ,  0.5 ],
       [ 0.  ,  1.  ]],

      [[-1.  ,  1.  ],
       [ 1.  ,  1.  ],
       [ 0.25, -0.5 ]]]]
  V: shape=(1, 2, 3, 2), dtype=float32
    [[[[ 1.  ,  0.  ],
       [ 0.  ,  1.  ],
       [-1.  ,  1.  ]],

      [[ 2.  , -2.  ],
       [ 0.5 ,  0.25],
       [-0.5 ,  0.  ]]]]

Outputs:
  Y: shape=(1, 4, 2, 2), dtype=float32
    [[[[-0.02356532,  0.6783799 ],
       [-0.02356531,  0.6783799 ]],

      [[-0.03533878,  0.6841799 ],
       [ 0.11724145,  0.6063233 ]],

      [[ 0.6482418 , -0.37858847],
       [ 0.9917567 , -0.74587834]],

      [[ 0.37784207, -0.12898168],
       [ 0.29831943, -0.26321504]]]]
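The GQA mapping can be checked against these numbers. With 4 query heads and 2 K/V heads, group_size is 2, so query heads 0 and 1 attend with K/V head 0 and query heads 2 and 3 with K/V head 1. The sketch below recomputes single output rows; `attend_row` is a hypothetical helper, and the 1/sqrt(head_size) scale is again an assumption that reproduces the listed outputs.

```python
import math


def attend_row(q, keys, values):
    """One query row of scaled dot-product attention against one K/V head."""
    scale = 1.0 / math.sqrt(len(q))  # assumed 1/sqrt(head_size)
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[j] for e, v in zip(exps, values)) for j in range(len(values[0]))]


K = [[[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]],     # K/V head 0
     [[-1.0, 1.0], [1.0, 1.0], [0.25, -0.5]]]  # K/V head 1
V = [[[1.0, 0.0], [0.0, 1.0], [-1.0, 1.0]],
     [[2.0, -2.0], [0.5, 0.25], [-0.5, 0.0]]]
group_size = 4 // 2

# Query head 0 maps to K/V head 0; query head 2 maps to K/V head 1.
for h, q in ((0, [0.1, 0.2]), (2, [0.5, 0.5])):
    kv = h // group_size
    print(h, kv, attend_row(q, K[kv], V[kv]))
```

Rows 0 and 1 of query head 0 produce the same output because [0.3, 0.4] raises every score by the same constant relative to [0.1, 0.2] (each key row sums to 1), and Softmax is invariant to such a shift.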