Grouped Query Attention in SDPA: PR#128898
Grouped Query Attention (GQA) has emerged as an important technique for reducing the memory usage of the KV cache during inference. It has become increasingly popular in foundation LLMs such as Llama 2 70B and Llama 3. We have added support for GQA to SDPA (scaled_dot_product_attention).
Reference paper: [2305.13245] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
API Updates:
- Added a new kwarg enable_gqa: bool to the existing scaled_dot_product_attention function. It defaults to False, which preserves the regular SDPA behavior.
- GQA cannot be the default behavior, since this grouped memory layout is not representable via strides; it therefore needs to be explicitly enabled.
- The third-to-last dimension (-3) of the query, key and value tensors is treated as the dedicated head dimension.
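To illustrate the opt-in behavior, here is a minimal sketch (assuming a PyTorch build where the enable_gqa kwarg is available, i.e. 2.5+): passing grouped head counts without enable_gqa=True is rejected, while the same call succeeds once the flag is set.

import torch
from torch.nn.functional import scaled_dot_product_attention

# 32 query heads sharing 8 key/value heads; dimension -3 is the head dimension.
query = torch.rand(2, 32, 128, 64)
key = torch.rand(2, 8, 128, 64)
value = torch.rand(2, 8, 128, 64)

try:
    # Default (enable_gqa=False): the mismatched head counts are not accepted.
    scaled_dot_product_attention(query, key, value, is_causal=True)
except RuntimeError as err:
    print("rejected without enable_gqa:", err)

# Explicitly enabling GQA makes the grouped layout legal.
out = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([2, 32, 128, 64])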
Important Notes:
- GQA is currently supported only by the math and flash attention kernels.
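A minimal sketch of restricting dispatch to these two backends with the standard torch.nn.attention.sdpa_kernel context manager; the shapes and dtypes below are illustrative assumptions, not values from the PR.

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

query = torch.rand(2, 32, 256, 64, device=device, dtype=dtype)
key = torch.rand(2, 8, 256, 64, device=device, dtype=dtype)
value = torch.rand(2, 8, 256, 64, device=device, dtype=dtype)

# Only allow the two GQA-capable backends inside this context.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]):
    out = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)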
Implementation details:
- The following constraints are validated before performing GQA (a Python-level sketch of these checks appears after this list):
  - query.num_heads % key.num_heads == 0
  - query.num_heads % value.num_heads == 0
  - key.num_heads == value.num_heads
  (num_heads is the size of the head dimension, i.e. dimension -3.)
- If all constraints are satisfied, the math kernel performs repeat_interleave on the key and value tensors so that they match the query tensor's number of heads, while the flash attention kernel handles the grouping directly; no repeat_interleave is needed there, so memory is conserved in the forward pass (see the equivalence check after the sample call below).
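A Python-level sketch of the shape checks listed above; the actual validation happens inside the SDPA dispatch code, and validate_gqa_shapes is a hypothetical helper used only for illustration.

import torch

def validate_gqa_shapes(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> None:
    # Dimension -3 is the head dimension for all three tensors.
    q_heads, k_heads, v_heads = query.size(-3), key.size(-3), value.size(-3)
    assert q_heads % k_heads == 0, "query heads must be a multiple of key heads"
    assert q_heads % v_heads == 0, "query heads must be a multiple of value heads"
    assert k_heads == v_heads, "key and value must have the same number of heads"

# 32 query heads grouped over 8 key/value heads: all three checks pass.
validate_gqa_shapes(
    torch.rand(1, 32, 128, 64),
    torch.rand(1, 8, 128, 64),
    torch.rand(1, 8, 128, 64),
)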
# Sample call to SDPA with GQA
import torch
from torch.nn.functional import scaled_dot_product_attention

batch, seq_len_q, seq_len_kv, D = 4, 512, 512, 64   # illustrative sizes
query = torch.rand(batch, 32, seq_len_q, D)   # 32 query heads
key = torch.rand(batch, 8, seq_len_kv, D)     # 8 key/value heads
value = torch.rand(batch, 8, seq_len_kv, D)
output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# Output shape: (batch, 32, seq_len_q, D)
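To make the math-kernel point above concrete, here is a small equivalence check (a sketch assuming PyTorch 2.5+): manually expanding the key/value heads with repeat_interleave and calling regular SDPA should produce the same output as passing the grouped tensors with enable_gqa=True.

import torch
from torch.nn.functional import scaled_dot_product_attention

batch, seq_len, D = 2, 128, 64
q = torch.rand(batch, 32, seq_len, D)
k = torch.rand(batch, 8, seq_len, D)
v = torch.rand(batch, 8, seq_len, D)

# GQA path: key/value heads are shared across groups of query heads.
out_gqa = scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

# Manual path: materialize the repeated heads, then call regular SDPA.
group_size = q.size(-3) // k.size(-3)          # 32 // 8 == 4
k_rep = k.repeat_interleave(group_size, dim=-3)
v_rep = v.repeat_interleave(group_size, dim=-3)
out_manual = scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

torch.testing.assert_close(out_gqa, out_manual)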
Benchmarking:
TorchTitan: PR#458
Profiling with Perfetto while running TorchTitan training for Llama 8B:
When SDPA is called without GQA, aten::reshape is called for both the key and the value tensor. Each reshape call takes approximately 120 µs, i.e. roughly 240 µs for the two calls, while the flash_fwd_kernel takes approximately 3 ms. In the ideal case, removing these calls ahead of the FlashAttention kernel saves ~240 µs per SDPA flash kernel run, making it ~6% faster than the previous runtime.
[Perfetto trace: SDPA call without enable_gqa]
[Perfetto trace: SDPA call with enable_gqa]
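For readers who want to collect a comparable trace, here is a minimal sketch using torch.profiler on a standalone SDPA call (an assumed setup, not the actual TorchTitan training loop); the exported Chrome trace can be opened in Perfetto (ui.perfetto.dev).

import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.profiler import profile, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

q = torch.rand(8, 32, 2048, 64, device=device, dtype=dtype)
k = torch.rand(8, 8, 2048, 64, device=device, dtype=dtype)
v = torch.rand(8, 8, 2048, 64, device=device, dtype=dtype)

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

# The JSON trace can be loaded in Perfetto or chrome://tracing.
prof.export_chrome_trace("sdpa_gqa_trace.json")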
SDPA Benchmarking: PR#130634
SDPA forward run times for different parameters (batch size, number of query heads, number of key-value heads) with enable_gqa=True and enable_gqa=False:
| Batch size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | Forward time with enable_gqa=True (ms) | Forward time with enable_gqa=False (ms) |
|---|---|---|---|---|---|---|---|
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |
| Batch size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | Forward time with enable_gqa=True (ms) | Forward time with enable_gqa=False (ms) |
|---|---|---|---|---|---|---|---|
| 1 | 128 | 16 | 2048 | 2048 | 2048 | 243.30 | 260.51 |
| 8 | 128 | 16 | 2048 | 2048 | 2048 | 1766.17 | 1856.47 |
| 16 | 128 | 16 | 2048 | 2048 | 2048 | 3515.05 | 3675.95 |
| 32 | 128 | 16 | 2048 | 2048 | 2048 | 6996.68 | 7318.03 |
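For reference, here is a sketch of how a single row of these tables could be timed with torch.utils.benchmark. The actual benchmark script lives in PR#130634; the head_dim = embed_dim // q_num_heads split below is an assumption made for illustration.

import torch
import torch.utils.benchmark as benchmark
from torch.nn.functional import scaled_dot_product_attention

batch, q_heads, kv_heads, seq_len, embed_dim = 8, 32, 8, 2048, 2048
head_dim = embed_dim // q_heads
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

q = torch.rand(batch, q_heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.rand(batch, kv_heads, seq_len, head_dim, device=device, dtype=dtype)
v = torch.rand(batch, kv_heads, seq_len, head_dim, device=device, dtype=dtype)

timer = benchmark.Timer(
    stmt="scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)",
    globals={"scaled_dot_product_attention": scaled_dot_product_attention, "q": q, "k": k, "v": v},
)
print(timer.timeit(100))  # mean forward time over 100 runs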