Added Grouped Query Attention to scaled_dot_product_attention API

Grouped Query Attention in SDPA: PR#128898

Grouped Query Attention (GQA) has emerged as an important technique for reducing the memory usage of the KV cache during inference. It has become increasingly popular in foundation LLMs such as Llama 2 70B and Llama 3. We have added support for it to SDPA.
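
To put the KV-cache saving in concrete terms, here is a back-of-the-envelope calculation (a sketch only; the Llama 2 70B-style configuration values below are assumptions for illustration, not taken from this PR):

# Back-of-the-envelope KV-cache size for one 4096-token sequence in bf16.
# Assumed Llama 2 70B-style config: 80 layers, 64 query heads, 8 KV heads, head_dim 128.
num_layers, head_dim, seq_len, bytes_per_elem = 80, 128, 4096, 2

def kv_cache_gib(num_kv_heads):
    # factor of 2 accounts for storing both the key and the value tensors
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem / 2**30

print(kv_cache_gib(64))  # without GQA (64 KV heads): 10.0 GiB
print(kv_cache_gib(8))   # with GQA (8 KV heads): 1.25 GiB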

Reference paper: [2305.13245] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

API Updates:

  • Added a new kwarg enable_gqa: bool to the existing scaled_dot_product_attention function. The default value is False, which preserves the existing SDPA behavior.
  • GQA cannot be on by default, as this memory layout is not supported via strides, hence it needs to be enabled explicitly.
  • The third-to-last dimension (-3) of the query, key, and value tensors is treated as the dedicated head dimension.

Important Notes:

  • GQA is supported only by the math and flash_attention kernels (a backend-selection sketch follows below).
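
If needed, the attention backend can be pinned explicitly via the torch.nn.attention.sdpa_kernel context manager; a minimal sketch, assuming a CUDA device with flash attention support and a PyTorch build that includes this PR:

# Pin SDPA to the flash-attention backend and call it with GQA enabled.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.rand(2, 32, 1024, 64, device="cuda", dtype=torch.float16)  # 32 query heads
k = torch.rand(2, 8, 1024, 64, device="cuda", dtype=torch.float16)   # 8 key/value heads
v = torch.rand(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)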

Implementation details:

  • The following constraints are validated before performing GQA (the head dimension here is dim -3, i.e. the number of heads):
  • query.size(-3) % key.size(-3) == 0
  • query.size(-3) % value.size(-3) == 0
  • key.size(-3) == value.size(-3)
  • If all conditions pass, the math kernel performs a repeat_interleave on the key and value tensors to match the query tensor's number of heads; in flash attention the repeat_interleave is not needed, hence memory is conserved in the forward pass (see the sketch below).
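
As a sanity check on the behaviour described above, enable_gqa=True can be compared against manually expanding the key/value heads with repeat_interleave and calling plain SDPA; a minimal sketch, assuming a PyTorch build that includes this PR:

# enable_gqa=True vs. manually repeating the KV heads before a regular SDPA call.
import torch
import torch.nn.functional as F

q = torch.rand(1, 32, 128, 64)  # 32 query heads
k = torch.rand(1, 8, 128, 64)   # 8 KV heads
v = torch.rand(1, 8, 128, 64)

out_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

groups = q.size(-3) // k.size(-3)  # 32 // 8 = 4 query heads per KV head
out_ref = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(groups, dim=-3),
    v.repeat_interleave(groups, dim=-3),
    is_causal=True,
)

print(torch.allclose(out_gqa, out_ref, atol=1e-6))  # expected: True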

# Sample call to SDPA with GQA enabled
import torch
import torch.nn.functional as F

batch, seq_len_q, seq_len_kv, D = 4, 2048, 2048, 64  # illustrative sizes

query = torch.rand(batch, 32, seq_len_q, D)   # 32 query heads
key = torch.rand(batch, 8, seq_len_kv, D)     # 8 key/value heads
value = torch.rand(batch, 8, seq_len_kv, D)
output = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output shape: (batch, 32, seq_len_q, D)

Benchmarking:

TorchTitan: PR#458

Profiling with Perfetto while running TorchTitan training for Llama 8B:

When SDPA is called without GQA, aten::reshape is called for both the key and value tensors. Each reshape takes approximately 120 us, i.e. roughly 240 us for the two calls, while flash_fwd_kernel takes approximately 3 ms. In an ideal scenario, removing these calls saves about 240 us per SDPA flash-kernel run, making it ~6% faster than the previous runtime.
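
A trace like this can be reproduced with torch.profiler; the sketch below uses illustrative shapes and a repeat_kv helper that mirrors the expand + reshape pattern of Llama-style models (both are assumptions, not the TorchTitan code):

# Observe the aten::reshape and flash kernel costs with torch.profiler.
import torch
import torch.nn.functional as F
from torch.profiler import profile, ProfilerActivity

def repeat_kv(x, n_rep):
    # expand + reshape, as in Llama-style reference code; the reshape materializes a copy
    b, h, s, d = x.shape
    return x[:, :, None, :, :].expand(b, h, n_rep, s, d).reshape(b, h * n_rep, s, d)

q = torch.rand(1, 32, 8192, 128, device="cuda", dtype=torch.bfloat16)
k = torch.rand(1, 8, 8192, 128, device="cuda", dtype=torch.bfloat16)
v = torch.rand(1, 8, 8192, 128, device="cuda", dtype=torch.bfloat16)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    # Without enable_gqa: the KV heads have to be expanded first.
    F.scaled_dot_product_attention(q, repeat_kv(k, 4), repeat_kv(v, 4), is_causal=True)
    # With enable_gqa: no expansion needed.
    F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))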

[Profiler trace: SDPA call without enable_gqa]

[Profiler trace: SDPA call with enable_gqa]

SDPA Benchmarking: PR#130634

SDPA forward run time for different parameters (batch size, number of query heads, number of key-value heads), with enable_gqa=True versus enable_gqa=False:

Batch size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward time, enable_gqa=True (ms) | forward time, enable_gqa=False (ms)
1          | 32          | 8            | 2048      | 2048       | 2048      | 100.71                             | 119.70
8          | 32          | 8            | 2048      | 2048       | 2048      | 539.78                             | 628.83
16         | 32          | 8            | 2048      | 2048       | 2048      | 1056.81                            | 1225.48
32         | 32          | 8            | 2048      | 2048       | 2048      | 2099.54                            | 2440.45

Batch size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward time, enable_gqa=True (ms) | forward time, enable_gqa=False (ms)
1          | 128         | 16           | 2048      | 2048       | 2048      | 243.30                             | 260.51
8          | 128         | 16           | 2048      | 2048       | 2048      | 1766.17                            | 1856.47
16         | 128         | 16           | 2048      | 2048       | 2048      | 3515.05                            | 3675.95
32         | 128         | 16           | 2048      | 2048       | 2048      | 6996.68                            | 7318.03
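
For reference, a comparison along the lines of these tables can be reproduced with torch.utils.benchmark; a sketch only, with illustrative shapes and a CUDA device in bf16 assumed:

# Time SDPA with and without enable_gqa using torch.utils.benchmark.
import torch
import torch.nn.functional as F
from torch.utils.benchmark import Timer

batch, q_heads, kv_heads, seq, head_dim = 8, 32, 8, 2048, 64
q = torch.rand(batch, q_heads, seq, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.rand(batch, kv_heads, seq, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.rand(batch, kv_heads, seq, head_dim, device="cuda", dtype=torch.bfloat16)

# Pre-expanded KV tensors for the enable_gqa=False baseline (expansion not timed here).
k_exp = k.repeat_interleave(q_heads // kv_heads, dim=1)
v_exp = v.repeat_interleave(q_heads // kv_heads, dim=1)

t_gqa = Timer(
    stmt="F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)",
    globals={"F": F, "q": q, "k": k, "v": v},
)
t_ref = Timer(
    stmt="F.scaled_dot_product_attention(q, k, v, is_causal=True)",
    globals={"F": F, "q": q, "k": k_exp, "v": v_exp},
)
print(t_gqa.blocked_autorange())
print(t_ref.blocked_autorange())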
