.. _op_ai_onnx_preview_FlexAttention:

FlexAttention
=============

- **Domain**: ``ai.onnx.preview``
- **Since version**: 1

Computes scaled dot-product attention over rank-4 (batched, multi-head) inputs, with optional user-provided customization subgraphs at two stages:

1. ``score_mod``: modifies the attention score tensor after the scaled Q·K^T product.
2. ``prob_mod``: modifies the probability tensor after Softmax.

This operator mirrors the capabilities of PyTorch's ``flex_attention``: https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html

Input Shapes (MUST be rank-4 tensors):

- Q: ``(batch_size, q_num_heads, q_sequence_length, head_size)``
- K: ``(batch_size, kv_num_heads, kv_sequence_length, head_size)``
- V: ``(batch_size, kv_num_heads, kv_sequence_length, v_head_size)``

Output Shape:

- Y: ``(batch_size, q_num_heads, q_sequence_length, v_head_size)``

FlexAttention Computation:

.. code-block:: text

    Scores = (Q @ K^T) * scale
    Scores = score_mod(Scores)        # if 'score_mod' is provided
    Probs  = Softmax(Scores, axis=-1)
    Probs  = prob_mod(Probs)          # if 'prob_mod' is provided
    Y      = Probs @ V

Grouped Query Attention (GQA):

When ``q_num_heads != kv_num_heads``, each K/V head is shared by a contiguous group of query heads in head-index order. Let ``group_size = q_num_heads / kv_num_heads``; then query head ``h`` uses K/V head ``floor(h / group_size)``. ``q_num_heads`` must be a multiple of ``kv_num_heads``. For example, with ``q_num_heads = 4`` and ``kv_num_heads = 2``, query heads 0 and 1 use K/V head 0, and query heads 2 and 3 use K/V head 1.

Modifier Subgraphs (score_mod, prob_mod):

Each modifier subgraph takes exactly one rank-4 tensor input and must produce exactly one rank-4 tensor output of the same shape and element type.

- ``score_mod`` input/output shape: ``(batch_size, q_num_heads, q_sequence_length, kv_sequence_length)``
- ``prob_mod`` input/output shape: ``(batch_size, q_num_heads, q_sequence_length, kv_sequence_length)``

The element type is determined by ``softmax_precision``; when it is not set, it defaults to float32 for non-double inputs and to double for double inputs.
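
The computation and GQA rule above can be sketched in NumPy. This is a minimal, non-normative illustration, not the operator's reference implementation: ``flex_attention_reference`` is a hypothetical name, the modifier subgraphs are modeled as plain Python callables, and when no ``scale`` is passed the sketch assumes ``1 / sqrt(head_size)`` (a value that reproduces the ``test_cc_flex_attention_basic`` and ``test_cc_flex_attention_gqa`` outputs). Softmax is computed in float32 for non-double inputs, following the ``softmax_precision`` default described above, and the result is cast back to the input type.

.. code-block:: python

    import numpy as np


    def flex_attention_reference(Q, K, V, scale=None, score_mod=None, prob_mod=None):
        """Non-normative NumPy sketch of FlexAttention (hypothetical helper).

        score_mod / prob_mod are Python callables standing in for the ONNX
        modifier subgraphs; each must return an array of the same shape.
        """
        if Q.ndim != 4 or K.ndim != 4 or V.ndim != 4:
            raise ValueError("Q, K and V must be rank-4 tensors")
        q_num_heads, kv_num_heads = Q.shape[1], K.shape[1]
        if q_num_heads % kv_num_heads != 0:
            raise ValueError("q_num_heads must be a multiple of kv_num_heads")

        # Assumed softmax precision: float64 for double inputs, else float32.
        work = np.float64 if Q.dtype == np.float64 else np.float32
        q, k, v = (x.astype(work) for x in (Q, K, V))

        # GQA: repeating each K/V head group_size times along the head axis
        # makes query head h line up with K/V head floor(h / group_size).
        group_size = q_num_heads // kv_num_heads
        k = np.repeat(k, group_size, axis=1)
        v = np.repeat(v, group_size, axis=1)

        if scale is None:
            scale = 1.0 / np.sqrt(Q.shape[-1])  # assumed default

        # Scores: (batch, q_num_heads, q_seq_len, kv_seq_len)
        scores = (q @ np.swapaxes(k, -1, -2)) * work(scale)
        if score_mod is not None:
            scores = score_mod(scores)

        # Numerically stable softmax over the kv_sequence_length axis.
        scores = scores - scores.max(axis=-1, keepdims=True)
        probs = np.exp(scores)
        probs = probs / probs.sum(axis=-1, keepdims=True)
        if prob_mod is not None:
            probs = prob_mod(probs)

        # Y: (batch, q_num_heads, q_seq_len, v_head_size), cast back to T1.
        return (probs @ v).astype(Q.dtype)


    # Inputs of test_cc_flex_attention_basic.
    Q = np.array([[[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.5], [1.0, -1.0]]]], dtype=np.float32)
    K = np.array([[[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [-1.0, 1.0]]]], dtype=np.float32)
    V = np.array([[[[1.0, 2.0], [3.0, 4.0]], [[-1.0, 0.0], [0.0, 1.0]]]], dtype=np.float32)
    Y = flex_attention_reference(Q, K, V)
    print(Y)

Repeating each K/V head with ``np.repeat`` is simply the most direct way to express the ``floor(h / group_size)`` mapping in array code; an optimized kernel would typically index the shared K/V head instead of materializing copies.
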

Masking can be expressed in ``score_mod`` by writing masked positions as ``-inf`` (or a large negative value appropriate for the target precision).

**Inputs**

- **Q** (*T1*): Query tensor with shape ``(batch_size, q_num_heads, q_sequence_length, head_size)``.
- **K** (*T1*): Key tensor with shape ``(batch_size, kv_num_heads, kv_sequence_length, head_size)``.
- **V** (*T1*): Value tensor with shape ``(batch_size, kv_num_heads, kv_sequence_length, v_head_size)``.

**Outputs**

- **Y** (*T1*): Output tensor with shape ``(batch_size, q_num_heads, q_sequence_length, v_head_size)``.

**Type Constraints**

- **T1**: Constrain Q, K, V, and Y to float tensors. Allowed types: tensor(bfloat16), tensor(double), tensor(float), tensor(float16).

Examples
--------

**test_cc_flex_attention_basic**

.. code-block:: text

    Inputs:
      Q: shape=(1, 2, 2, 2), dtype=float32
        [[[[ 1. ,  0. ],
           [ 0. ,  1. ]],
          [[ 0.5,  0.5],
           [ 1. , -1. ]]]]
      K: shape=(1, 2, 2, 2), dtype=float32
        [[[[ 1.,  0.],
           [ 0.,  1.]],
          [[ 1.,  1.],
           [-1.,  1.]]]]
      V: shape=(1, 2, 2, 2), dtype=float32
        [[[[ 1.,  2.],
           [ 3.,  4.]],
          [[-1.,  0.],
           [ 0.,  1.]]]]

    Outputs:
      Y: shape=(1, 2, 2, 2), dtype=float32
        [[[[ 1.6604769 ,  2.660477  ],
           [ 2.339523  ,  3.339523  ]],
          [[-0.66976154,  0.33023846],
           [-0.80442965,  0.19557032]]]]

**test_cc_flex_attention_gqa**

.. code-block:: text

    Inputs:
      Q: shape=(1, 4, 2, 2), dtype=float32
        [[[[ 0.1 ,  0.2 ],
           [ 0.3 ,  0.4 ]],
          [[-0.1 ,  0.05],
           [ 0.2 , -0.3 ]],
          [[ 0.5 ,  0.5 ],
           [ 0.  ,  1.  ]],
          [[ 1.  ,  0.  ],
           [ 0.5 , -0.5 ]]]]
      K: shape=(1, 2, 3, 2), dtype=float32
        [[[[ 1.  ,  0.  ],
           [ 0.5 ,  0.5 ],
           [ 0.  ,  1.  ]],
          [[-1.  ,  1.  ],
           [ 1.  ,  1.  ],
           [ 0.25, -0.5 ]]]]
      V: shape=(1, 2, 3, 2), dtype=float32
        [[[[ 1.  ,  0.  ],
           [ 0.  ,  1.  ],
           [-1.  ,  1.  ]],
          [[ 2.  , -2.  ],
           [ 0.5 ,  0.25],
           [-0.5 ,  0.  ]]]]

    Outputs:
      Y: shape=(1, 4, 2, 2), dtype=float32
        [[[[-0.02356532,  0.6783799 ],
           [-0.02356531,  0.6783799 ]],
          [[-0.03533878,  0.6841799 ],
           [ 0.11724145,  0.6063233 ]],
          [[ 0.6482418 , -0.37858847],
           [ 0.9917567 , -0.74587834]],
          [[ 0.37784207, -0.12898168],
           [ 0.29831943, -0.26321504]]]]
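
The ``-inf`` masking convention for ``score_mod`` can be illustrated with a causal mask. The sketch below is a hypothetical NumPy stand-in for a ``score_mod`` subgraph; ``causal_score_mod`` and ``softmax_last_axis`` are illustrative names, not part of the operator, and the mask assumes top-left alignment, where query position ``i`` may attend to key positions ``0..i``.

.. code-block:: python

    import numpy as np


    def causal_score_mod(scores):
        """Hypothetical score_mod: mask keys that lie after the query position.

        scores has shape (batch, q_num_heads, q_seq_len, kv_seq_len).
        Assumes top-left alignment (query i may attend to keys 0..i).
        """
        q_len, kv_len = scores.shape[-2], scores.shape[-1]
        q_idx = np.arange(q_len)[:, None]
        kv_idx = np.arange(kv_len)[None, :]
        mask = kv_idx > q_idx  # True where the key is "in the future"
        # The mask broadcasts over batch and heads; shape and dtype are kept,
        # as the modifier contract requires.
        return np.where(mask, np.array(-np.inf, dtype=scores.dtype), scores)


    def softmax_last_axis(x):
        # Numerically stable softmax over the kv_sequence_length axis.
        x = x - x.max(axis=-1, keepdims=True)
        e = np.exp(x)
        return e / e.sum(axis=-1, keepdims=True)


    rng = np.random.default_rng(0)
    scores = rng.standard_normal((1, 2, 3, 3)).astype(np.float32)
    probs = softmax_last_axis(causal_score_mod(scores))
    print(probs[0, 0])  # upper triangle is exactly 0; rows sum to 1 up to rounding

Because ``exp(-inf)`` is exactly zero, masked positions receive zero probability. A row in which every position is ``-inf`` would instead produce NaN after Softmax; that cannot happen here because every query can see key 0, but it is one reason a large finite negative value may be preferable for masks that can block an entire row.
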