This post grew out of my frustrations learning about Multi-Head Attention (MHA) from the perspective of an ML framework developer. I am sharing it here in case others find it helpful.
Introduction
Multi-Head Attention (MHA) is an operator that was initially introduced as part of the Transformer architecture in the influential paper "Attention Is All You Need" by Vaswani et al. Since its introduction, MHA has been incorporated into various Machine Learning (ML) frameworks, including PyTorch. While there is a wealth of resources aimed at ML practitioners that provide a high-level overview of MHA, they often do not delve into the specifics of its implementation. These details are quite helpful for ML framework developers, who have to provide implementations of MHA. This post aims to address this gap by giving ML framework developers a crash course that emphasizes exactly how the tensors are transformed.
First, we will look at how Scaled Dot Product Attention (SDPA), a component of MHA, operates on its input tensors. Then, we will provide some intuition on what SDPA is calculating (this section is optional). Next, we will look at the whole MHA operator. Finally, we will explore the PyTorch operators F.scaled_dot_product_attention, F.multi_head_attention_forward, and torch.nn.MultiheadAttention.
Scaled Dot Product Attention
Scaled Dot Product Attention (SDPA) is a component of the Multi-Head Attention operator. We will first present its definition and then build intuition on what it computes using the soft database lookup interpretation. If you don’t care for the intuition, feel free to skip it.
Definition
Scaled dot product attention is defined as

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$

where Q, K, V stand for queries, keys, and values respectively, and d is the dimension shared by the queries and keys.
We explain why these inputs are called queries, keys, and values in the Soft DB Lookup Interpretation section.
Figure 1: Reference SDPA Data Flow Graph annotated with output sizes. We omit the batch dimension(s) for simplicity.
You will notice that there is an additional “mask” node in Figure 1 that is not in the formula. This node enables us to mask out entries that are either irrelevant or should not be used (e.g. when training a language model, don’t look at the word you’re supposed to be predicting). The mask typically adds -inf to the scaled scores of the entries that should be ignored, which turns them into 0s after the softmax.
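To make Figure 1 concrete, here is a minimal naive sketch in PyTorch. The helper name reference_sdpa is hypothetical, and we assume the additive-mask convention described above (0 where attention is allowed, -inf where it is not). We check it against F.scaled_dot_product_attention, which accepts such a float mask via attn_mask; the leading batch dimension is only there to match that API's input layout.

```python
import math

import torch
import torch.nn.functional as F


def reference_sdpa(q, k, v, mask=None):
    """Naive SDPA following Figure 1 (hypothetical helper, for illustration only).

    q: [..., q_len, qk_dim], k: [..., kv_len, qk_dim], v: [..., kv_len, v_dim]
    mask: optional additive float mask broadcastable to [..., q_len, kv_len],
          0 where attention is allowed and -inf where it is not.
    """
    d = q.size(-1)                              # dimension shared by queries and keys
    a = q @ k.transpose(-2, -1)                 # A: [..., q_len, kv_len]
    a_scaled = a / math.sqrt(d)                 # A_scaled
    if mask is not None:
        a_scaled = a_scaled + mask              # -inf entries become 0 after the softmax
    a_probs = torch.softmax(a_scaled, dim=-1)   # A_probs: each row sums to 1
    return a_probs @ v                          # output: [..., q_len, v_dim]


torch.manual_seed(0)
batch, q_len, kv_len, qk_dim, v_dim = 2, 3, 4, 8, 5
q = torch.randn(batch, q_len, qk_dim)
k = torch.randn(batch, kv_len, qk_dim)
v = torch.randn(batch, kv_len, v_dim)

# Example mask: query i may only look at keys 0..i (every row keeps at least one key).
allowed = torch.ones(q_len, kv_len, dtype=torch.bool).tril()
mask = torch.zeros(q_len, kv_len).masked_fill(~allowed, float("-inf"))

out = reference_sdpa(q, k, v, mask)
expected = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)                                # torch.Size([2, 3, 5])
print(torch.allclose(out, expected, atol=1e-5))
```

Because query 0 is only allowed to see key 0, its output is exactly V[0]: all of the softmax weight lands on the single unmasked key.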
Soft Database Lookup Interpretation
The SDPA operator can be thought of as performing q_len soft database lookups on a key-value store of kv_len key-value pairs.
The queries, keys, and values are all vectors.
First, let us consider what a hard (regular) DB lookup would be. For each query Q[i]:
- Find the most relevant key K[j]
- Return its corresponding value V[j]
Note the following:
- Queries and keys should be comparable; in practice this means they have the same dimension qk_dim
- There are as many keys as values (both kv_len), but there may be a different number of queries (q_len)
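The hard lookup can be sketched in a few lines of PyTorch. This is a toy illustration: hard_lookup and the tensors are hypothetical, and we assume the dot product as the relevance measure (it is motivated in the next section).

```python
import torch


def hard_lookup(q, k, v):
    """Hard DB lookup (hypothetical helper): return the value of the single best key.

    q: [q_len, qk_dim], k: [kv_len, qk_dim], v: [kv_len, v_dim]
    """
    scores = q @ k.T                 # [q_len, kv_len]: relevance of every key to every query
    best = scores.argmax(dim=-1)     # index j of the most relevant key for each query i
    return v[best]                   # [q_len, v_dim]


k = torch.eye(3)                               # three orthogonal keys, qk_dim = 3
v = torch.tensor([[10.0], [20.0], [30.0]])     # one scalar value per key, v_dim = 1
q = torch.tensor([[0.1, 0.9, 0.0],             # most similar to key 1
                  [0.0, 0.2, 0.7]])            # most similar to key 2
print(hard_lookup(q, k, v))                    # values 20 and 30
```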
In the soft lookup case, instead of finding the single most relevant key K[j] and returning its corresponding value, we calculate the relevance of each key to the query and return a weighted average of the values based on the relevance scores.
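For comparison, here is a matching sketch of the soft lookup on the same toy data (soft_lookup is again a hypothetical helper). It uses the scaled dot product and softmax tools introduced below. Every value contributes to the result, but the key with the highest weight is the same one a hard lookup would pick.

```python
import math

import torch


def soft_lookup(q, k, v):
    """Soft DB lookup (hypothetical helper): weight every value by its key's relevance."""
    scores = q @ k.T / math.sqrt(q.size(-1))   # scaled dot-product relevance, [q_len, kv_len]
    weights = torch.softmax(scores, dim=-1)     # each row is a probability distribution over keys
    return weights @ v, weights                 # weighted average of the values, [q_len, v_dim]


k = torch.eye(3)
v = torch.tensor([[10.0], [20.0], [30.0]])
q = torch.tensor([[0.1, 0.9, 0.0],
                  [0.0, 0.2, 0.7]])
out, weights = soft_lookup(q, k, v)
print(weights.sum(dim=-1))       # each row sums to 1
print(weights.argmax(dim=-1))    # heaviest keys: 1 and 2, as in the hard lookup
print(out)                       # blends of 10, 20 and 30
```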
To see why the scaled_dot_product_attention operator calculates a soft lookup, we first need to understand two tools used in SDPA: dot product similarity and softmax. Then, we will present a node-by-node intuition on how SDPA calculates this soft database lookup.
Scaled Dot Product Similarity
A similarity measure calculates the similarity between vectors q and k. In the context of soft database lookup, we need a similarity measure when comparing a query against keys to decide how much weight to give each key. It is roughly the opposite of a distance metric: similarity is greater for similar vectors and smaller for dissimilar vectors. Examples include the dot product and cosine similarity.
Why do dot product and cosine work as similarity measures? Let’s look at cosine first because it is easier to understand:
Cosine
The cosine similarity between u and v is simply the cosine of the angle θ between them. We can calculate it using the formula:

$$\cos(\theta) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}$$
Why is cosine a good similarity measure? When two vectors are similar, they point in the same direction, i.e. the angle between them is small, so its cosine is large (close to 1), and vice versa (see graph). However, notice that cosine ignores the magnitudes of the vectors and focuses only on direction.
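A quick numeric check of these two claims, with toy vectors chosen for illustration: a vector and a scaled copy of it have cosine similarity 1, and orthogonal vectors have cosine similarity 0. The helper name cosine is hypothetical; F.cosine_similarity is PyTorch's built-in equivalent.

```python
import torch
import torch.nn.functional as F

u = torch.tensor([1.0, 1.0])
v = torch.tensor([2.0, 2.0])      # same direction as u, twice the length
w = torch.tensor([1.0, -1.0])     # orthogonal to u


def cosine(a, b):
    # cos(theta) = (a . b) / (|a| |b|)
    return torch.dot(a, b) / (a.norm() * b.norm())


print(cosine(u, v))                        # ~1.0: same direction, magnitude ignored
print(cosine(u, w))                        # 0.0: orthogonal
print(F.cosine_similarity(u, v, dim=0))    # the built-in agrees
```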
Dot product
When the magnitudes of the vectors in question are meaningful, we would want a similarity measure that takes this into account.
Rearranging the cosine equation, we get

$$u \cdot v = \lVert u \rVert \, \lVert v \rVert \cos(\theta)$$

From the above formula, it’s clear that for fixed magnitudes the dot product is directly proportional to the cosine. When the magnitudes of the input vectors are meaningful, it’s better to just take the dot product (which also has the added benefit of being computationally cheaper). The dot product is a good similarity measure for roughly the same reason cosine is, with the added benefit of accounting for magnitude.
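The difference shows up as soon as two keys point in the same direction but have different lengths. In this toy example (the helper similarities is hypothetical), cosine cannot tell the two keys apart while the dot product can:

```python
import torch


def similarities(u, v):
    """Return (dot product, cosine similarity) as Python floats."""
    dot = torch.dot(u, v)
    cos = dot / (u.norm() * v.norm())   # u . v = |u| |v| cos(theta), solved for cos(theta)
    return dot.item(), cos.item()


u = torch.tensor([1.0, 1.0])
dot_small, cos_small = similarities(u, torch.tensor([1.0, 1.0]))
dot_large, cos_large = similarities(u, torch.tensor([3.0, 3.0]))   # same direction, 3x the magnitude
print(dot_small, cos_small)   # dot 2.0, cosine ~1.0
print(dot_large, cos_large)   # dot 6.0, cosine ~1.0
```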
Scaled Dot Product
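Dot products can grow really large for high-dimensional vectors, which can cause numerical issues (Vaswani et al. point out that large dot products push the softmax into regions where it has extremely small gradients). We can account for this by reducing the variance using a scaling factor. Vaswani et al. use the scaled dot product similarity measure defined as:

$$\mathrm{sim}(u, v) = \frac{u \cdot v}{\sqrt{d}}$$

where d is the dimension of u and v. If the components of u and v are independent with mean 0 and variance 1, then u · v has mean 0 and variance d, so dividing by √d brings the variance back to 1.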
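We can check the variance argument empirically. This sketch assumes, as in the argument above, that the vector entries are i.i.d. standard normal; real queries and keys need not be distributed this way.

```python
import math

import torch

torch.manual_seed(0)
n = 4000                                  # random (u, v) pairs per dimension
results = {}
for d in (4, 64, 1024):
    u = torch.randn(n, d)                 # entries drawn i.i.d. from N(0, 1)
    v = torch.randn(n, d)
    dots = (u * v).sum(dim=-1)            # n raw dot products
    scaled = dots / math.sqrt(d)          # n scaled dot products
    results[d] = (dots.var().item(), scaled.var().item())
    print(f"d={d:4d}  var(u.v)={results[d][0]:8.1f}  var(u.v/sqrt(d))={results[d][1]:.2f}")
```

The raw variance grows roughly linearly with d, while the scaled variance stays close to 1 for every d.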
Softmax
The softmax function is best illustrated by the following diagram.
Figure 2: How softmax transforms a vector. Source
i.e. softmax(v) := exp(v) / sum(exp(v)).
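Written directly, the formula overflows for large inputs in float32. A common trick, which we assume here rather than attribute to any particular framework's kernel, is to subtract max(v) first. This does not change the result because exp(v - c) / sum(exp(v - c)) = exp(v) / sum(exp(v)) for any constant c. The helper name softmax is ours.

```python
import torch


def softmax(v):
    # softmax(v) := exp(v) / sum(exp(v)), computed after subtracting max(v) to avoid overflow
    e = torch.exp(v - v.max())
    return e / e.sum()


v = torch.tensor([1.0, 3.0, -2.0, 0.5])
print(torch.allclose(softmax(v), torch.softmax(v, dim=0)))   # matches the built-in

big = torch.tensor([1000.0, 1001.0])
naive = torch.exp(big) / torch.exp(big).sum()
print(naive)          # tensor([nan, nan]): exp(1000) overflows to inf in float32
print(softmax(big))   # ~[0.2689, 0.7311]
```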
Softmax can be thought of as a soft argmax function. As we will see later, this soft argmax intuition is useful for understanding SDPA as a soft database lookup.
Softmax is great because it:
- Respects the order of the input vector (notice that the largest and smallest values occur at indices 2 and 3 in both the input to and the output of softmax above)
- Produces outputs that are always positive and sum to 1 (a probability distribution)
But simpler normalization schemes can also satisfy these properties (e.g. dividing a vector of positive entries by its sum). So why do we prefer softmax? That is out of scope for this post, but see this discussion for more details.
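The two properties above, and the "soft argmax" behavior, are easy to see numerically. In this toy sketch, multiplying the input by a growing factor (a device we introduce only for illustration) pushes the softmax output towards the one-hot vector that a hard argmax would produce.

```python
import torch
import torch.nn.functional as F

v = torch.tensor([1.0, 3.0, -2.0, 0.5])
p = torch.softmax(v, dim=0)

# Order is preserved: the largest/smallest entries sit at the same indices.
assert p.argmax() == v.argmax() and p.argmin() == v.argmin()
# Outputs are positive and sum to 1, i.e. a probability distribution.
assert (p > 0).all() and torch.isclose(p.sum(), torch.tensor(1.0))

# Sharpening the input turns softmax into (almost) a hard argmax.
for scale in (1.0, 5.0, 50.0):
    print(scale, torch.softmax(scale * v, dim=0))
one_hot = F.one_hot(v.argmax(), num_classes=v.numel()).float()
print(one_hot)   # tensor([0., 1., 0., 0.])
```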
Soft Lookup Intuition: Node-by-Node
In the following table, we build up the intuition for what each node/operator in the graph is doing, culminating in an explanation of how the output is the soft database lookup.
Node Name | Formula | Shape | Intuition + Notes |
---|---|---|---|
Q | Q | [q_len x qk_dim] | There are q_len (independent) queries in Q. Each query Q[i] is a qk_dim-dimensional vector. |
K | K | [kv_len x qk_dim] | There are kv_len keys in K. Each key K[j] is a qk_dim-dimensional vector. Notice that kv_len can be different from q_len. |
V | V | [kv_len x v_dim] | There are kv_len values in V. Each value V[j] is a v_dim-dimensional vector. Notice that v_dim can be different from qk_dim. |
A | QK^T | [q_len x kv_len] | A[i][j] is the dot product similarity between Q[i] and K[j]. |
A_scaled | QK^T / √d | [q_len x kv_len] | As discussed, we scale these dot products (d = qk_dim) to prevent numerical issues when qk_dim is large. |
A_probs | softmax(QK^T / √d) | [q_len x kv_len] | The softmax, taken along each row (over the keys), calculates a soft argmax over the values based on the relevance of their respective keys to the query. If this were a hard DB lookup, we would have done an argmax to find the best key for every query (so that A_probs[i][j] = 1 if key j is the best match for query i, and A_probs[i][j] = 0 otherwise). Softmax gets us a soft version of this, where 0 < A_probs[i][j] < 1 is the weight to put on V[j] when calculating the lookup for query i. |
output | softmax(QK^T / √d)V | [q_len x v_dim] | output[i] is a v_dim-dimensional vector. It is the average of all values weighted by the relevance/similarity of their keys to Q[i]. This is the soft lookup for query i. |
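The table can be traced node by node with small random tensors. The sizes below are arbitrary and deliberately all different, so that any shape mix-up would fail loudly.

```python
import math

import torch

torch.manual_seed(0)
q_len, kv_len, qk_dim, v_dim = 2, 5, 4, 3
Q = torch.randn(q_len, qk_dim)
K = torch.randn(kv_len, qk_dim)
V = torch.randn(kv_len, v_dim)

A = Q @ K.T                                   # [q_len x kv_len]
A_scaled = A / math.sqrt(qk_dim)              # [q_len x kv_len]
A_probs = torch.softmax(A_scaled, dim=-1)     # [q_len x kv_len], softmax over the keys
output = A_probs @ V                          # [q_len x v_dim]

for name, t in [("A", A), ("A_scaled", A_scaled), ("A_probs", A_probs), ("output", output)]:
    print(name, tuple(t.shape))

# A[i][j] is the dot product between query i and key j.
i, j = 1, 3
assert torch.isclose(A[i, j], torch.dot(Q[i], K[j]), atol=1e-6)
# output[i] is the A_probs[i]-weighted average of the rows of V.
assert torch.allclose(output[i], (A_probs[i].unsqueeze(-1) * V).sum(dim=0), atol=1e-6)
```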
Multi-Head Attention
Multi-head attention performs multiple attention operations in parallel on different projections (“heads”) of the input data (queries, keys, values). Multi-head attention requires that queries, keys, and values be split into num_heads heads each; this is relaxed in variants like Grouped Query Attention (used in some Llama models), where multiple query heads share each key/value head. Each attention head can capture different features/relationships between input tokens.
As before, we will assume unbatched inputs. Multi-head attention roughly does the following (the compute operations are the linear projections and SDPA; the remaining steps only move data around):
- Linearly projects the queries, keys, and values independently
- Splits each projection along the embedding dimension into num_heads “heads”
- Computes scaled dot product attention for each head (the heads act as an additional batch dimension)
- Concatenates the per-head SDPA outputs back along the embedding dimension
- Linearly projects the concatenated output
- Returns the linearly projected output
Of course, to cajole the tensors into the right format, MHA implementations might use a combination of chunks, concats, permutes, reshapes, etc. The PyTorch MHA operator handles many different tensor formats, each requiring its own series of data movement operations. We will cover some of this in the PyTorch section.
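As one example of such data movement, the split and combine steps can be written with a reshape plus a transpose. The helpers split_heads and merge_heads are hypothetical names; we assume head h owns the contiguous slice of embedding columns [h * head_dim, (h + 1) * head_dim).

```python
import torch


def split_heads(x, num_heads):
    """[seq_len, embed_dim] -> [num_heads, seq_len, head_dim]"""
    seq_len, embed_dim = x.shape
    head_dim = embed_dim // num_heads   # embed_dim must be divisible by num_heads
    return x.reshape(seq_len, num_heads, head_dim).transpose(0, 1)


def merge_heads(x):
    """[num_heads, seq_len, head_dim] -> [seq_len, num_heads * head_dim]"""
    num_heads, seq_len, head_dim = x.shape
    return x.transpose(0, 1).reshape(seq_len, num_heads * head_dim)


x = torch.arange(12.0).reshape(2, 6)   # seq_len = 2, embed_dim = 6
heads = split_heads(x, num_heads=3)    # 3 heads with head_dim = 2
print(heads.shape)                     # torch.Size([3, 2, 2])
print(heads[1])                        # columns 2..3 of x: rows [2., 3.] and [8., 9.]
assert torch.equal(merge_heads(heads), x)   # merging undoes splitting
```

After the split, the leading num_heads dimension can be treated exactly like a batch dimension by SDPA.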
Definition
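The “Attention Is All You Need” paper gives the definition of Multi-Head Attention as:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^O, \quad \text{where } \mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

Here h is num_heads, and W_i^Q, W_i^K, W_i^V, and W^O are learned projection matrices. Unlike the formula for SDPA, the closed-form formula for MHA is rather confusing. It is easier to look at the following data flow graph to understand what is going on.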
Figure 3: Reference MHA Data Flow Graph annotated with output sizes. We omit the batch dimension(s) for simplicity.
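Putting the steps together, here is a naive, unbatched sketch of the whole data flow graph. The function reference_mha is hypothetical; we assume PyTorch's packed-weight layout (rows of in_proj_weight ordered as W_q, W_k, W_v) so that we can compare against F.multi_head_attention_forward. Note that inside each head the SDPA scale uses d = head_dim, not the full embedding dimension.

```python
import math

import torch
import torch.nn.functional as F


def reference_mha(query, key, value, num_heads, in_proj_weight, in_proj_bias,
                  out_proj_weight, out_proj_bias):
    """Naive, unbatched MHA: query [q_len, E], key/value [kv_len, E]."""
    E = query.size(-1)
    head_dim = E // num_heads
    # 1. Project queries, keys and values independently (packed weight: q rows, k rows, v rows).
    w_q, w_k, w_v = in_proj_weight.chunk(3)
    b_q, b_k, b_v = in_proj_bias.chunk(3)
    q, k, v = F.linear(query, w_q, b_q), F.linear(key, w_k, b_k), F.linear(value, w_v, b_v)
    # 2. Split the embedding dimension into heads: [len, E] -> [num_heads, len, head_dim].
    q, k, v = (t.reshape(t.size(0), num_heads, head_dim).transpose(0, 1) for t in (q, k, v))
    # 3. SDPA per head; the head dimension acts as a batch dimension.
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    heads = torch.softmax(scores, dim=-1) @ v                  # [num_heads, q_len, head_dim]
    # 4. Concatenate the heads back: [q_len, num_heads * head_dim] = [q_len, E].
    concat = heads.transpose(0, 1).reshape(query.size(0), E)
    # 5. Output projection.
    return F.linear(concat, out_proj_weight, out_proj_bias)


torch.manual_seed(0)
E, num_heads, q_len, kv_len = 8, 2, 3, 5
query, key, value = torch.randn(q_len, E), torch.randn(kv_len, E), torch.randn(kv_len, E)
in_w, in_b = torch.randn(3 * E, E), torch.randn(3 * E)
out_w, out_b = torch.randn(E, E), torch.randn(E)

ours = reference_mha(query, key, value, num_heads, in_w, in_b, out_w, out_b)
theirs, _ = F.multi_head_attention_forward(
    query, key, value, E, num_heads, in_w, in_b,
    None, None, False,      # bias_k, bias_v, add_zero_attn
    0.0, out_w, out_b,      # dropout_p, out_proj_weight, out_proj_bias
    training=False, need_weights=False)
print(ours.shape)           # torch.Size([3, 8])
print(torch.allclose(ours, theirs, rtol=1e-4, atol=1e-4))
```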
PyTorch Implementations
The diagrams provided above are for a naive/reference implementation. Lots of work has been done on creating far more efficient implementations of these operators in PyTorch and other frameworks.
Scaled Dot Product Attention
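The SDPA operator is fairly straightforward. You can read more about the scaled_dot_product_attention implementation in the official docs. When available, a more efficient implementation like FlashAttention is used. One thing to note here is that the “mask” can also be a boolean tensor, where True indicates that a value should take part in attention. Be careful: boolean-mask conventions are not consistent across APIs, even within PyTorch. For example, a boolean attn_mask or key_padding_mask passed to torch.nn.MultiheadAttention uses True to mark positions that are not allowed to attend, so always check the convention of the API (or framework, e.g. JAX) you are targeting.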
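The sketch below compares the mask forms on toy tensors. For F.scaled_dot_product_attention, a boolean mask, the equivalent additive float mask, and is_causal=True all give the same result. For torch.nn.MultiheadAttention, the same boolean pattern has to be inverted (~keep) to match the float mask. The variable names are ours.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
L = 4
q, k, v = (torch.randn(1, L, 8) for _ in range(3))   # [batch, seq_len, dim]

# F.scaled_dot_product_attention: boolean True means "may attend".
keep = torch.ones(L, L, dtype=torch.bool).tril()
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=keep)

# The equivalent additive float mask: 0 where allowed, -inf where not.
float_mask = torch.zeros(L, L).masked_fill(~keep, float("-inf"))
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

# is_causal=True applies the same lower-triangular mask internally.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# nn.MultiheadAttention: boolean True means "may NOT attend", so we pass ~keep.
mha = nn.MultiheadAttention(8, num_heads=2, batch_first=True)
mha_bool, _ = mha(q, k, v, attn_mask=~keep)
mha_float, _ = mha(q, k, v, attn_mask=float_mask)

print(torch.allclose(out_bool, out_float, atol=1e-5),
      torch.allclose(out_bool, out_causal, atol=1e-5),
      torch.allclose(mha_bool, mha_float, atol=1e-5))
```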
Multi-Head Attention
There are two different ways to use Multi-Head Attention in PyTorch: torch.nn.MultiheadAttention and F.multi_head_attention_forward. The typical way is to use torch.nn.MultiheadAttention. Its forward method is functionally equivalent to F.multi_head_attention_forward; in fact, whenever it does not take its fast path, it simply calls F.multi_head_attention_forward. If you want to go digging, F.multi_head_attention_forward provides better insight into what the MHA operator actually does.
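One way to convince yourself of this equivalence is to call both with the module's own parameters. This sketch relies on a freshly constructed module being in training mode, which (per the snippet below) rules out the fast path; dropout defaults to 0.0, so the comparison is deterministic.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
E, num_heads = 8, 2
mha = nn.MultiheadAttention(E, num_heads)   # dropout=0.0 by default; starts in training mode
query, key, value = torch.randn(3, E), torch.randn(5, E), torch.randn(5, E)   # unbatched

out_module, _ = mha(query, key, value, need_weights=False)
out_functional, _ = F.multi_head_attention_forward(
    query, key, value, E, num_heads,
    mha.in_proj_weight, mha.in_proj_bias,
    mha.bias_k, mha.bias_v, mha.add_zero_attn,
    mha.dropout, mha.out_proj.weight, mha.out_proj.bias,
    training=mha.training, need_weights=False)
print(torch.allclose(out_module, out_functional))
```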
torch.nn.MultiheadAttention
```python
from typing import Optional, Tuple

import torch.nn.functional as F
from torch import Tensor


# Condensed excerpt: this is a method of torch.nn.MultiheadAttention.
def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
    is_batched = query.dim() == 3
    why_not_fast_path = ''
    # some checks to figure out whether we can use fast_path
    # (each failed check stores its reason in why_not_fast_path)
    if not why_not_fast_path:
        # set up and do fast path inference; this may return early
        ...

    any_nested = query.is_nested or key.is_nested or value.is_nested
    assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
                            f"The fast path was not hit because {why_not_fast_path}")

    if self.batch_first and is_batched:
        # make sure that the transpose op does not affect the "is" property
        if key is value:
            if query is key:
                query = key = value = query.transpose(1, 0)
            else:
                query, key = (x.transpose(1, 0) for x in (query, key))
                value = key
        else:
            query, key, value = (x.transpose(1, 0) for x in (query, key, value))

    if not self._qkv_same_embed_dim:
        # kdim/vdim differ from embed_dim: use separate q/k/v projection weights
        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
            v_proj_weight=self.v_proj_weight,
            average_attn_weights=average_attn_weights,
            is_causal=is_causal)
    else:
        # same embedding dims: use the packed in_proj_weight
        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,
            is_causal=is_causal)
    if self.batch_first and is_batched:
        # transpose back to (batch, seq_len, embed_dim)
        return attn_output.transpose(1, 0), attn_output_weights
    else:
        return attn_output, attn_output_weights
```
Snippet: MultiheadAttention.forward (condensed)
This module’s forward method is functionally identical to F.multi_head_attention_forward, but adds some optional pre-processing and routing to a faster implementation (the fast path). You can find the full list of arguments to the constructor and the forward method in the docs. Here are some peculiarities worth calling out for framework engineers:
Batches
So far we have assumed no batch dimension, but MHA can handle batched input. By default, batched tensors should have shape (seq_len, batch_size, embed_dim), i.e. the batch dimension is the middle dimension. However, you can set batch_first=True in torch.nn.MultiheadAttention to make it the outermost (first) dimension. We found no convincing answers on why the default is False; our best guess is that this is due to historical reasons.
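The following sketch checks that batch_first only changes the expected layout. Two modules share the same weights (batch_first is not part of the state dict), and the batch-first module is fed the transposed input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
E, num_heads, batch, seq_len = 8, 2, 4, 6
seq_first = nn.MultiheadAttention(E, num_heads)                      # expects (seq_len, batch, E)
batch_first = nn.MultiheadAttention(E, num_heads, batch_first=True)  # expects (batch, seq_len, E)
batch_first.load_state_dict(seq_first.state_dict())                  # identical weights

x = torch.randn(seq_len, batch, E)
xb = x.transpose(0, 1)                                               # (batch, seq_len, E)
out_seq_first, _ = seq_first(x, x, x)
out_batch_first, _ = batch_first(xb, xb, xb)
print(out_seq_first.shape, out_batch_first.shape)   # torch.Size([6, 4, 8]) torch.Size([4, 6, 8])
print(torch.allclose(out_seq_first, out_batch_first.transpose(0, 1), atol=1e-5))
```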
Return Type, need_weights, average_attn_weights
This operator returns a tuple whose first value is the result of MHA described above. If need_weights=True, the second value is the attention matrix computed inside the scaled dot product attention step (i.e. softmax(QK^T / √d)); otherwise it is None.
Why does this option exist? The attention matrix can help in analyzing/interpreting network weights. However, returning it requires materializing the softmax node fully, something that fast implementations like FlashAttention skip. Turning this on therefore forces a much slower computation of SDPA. This option can likely be turned off in most production scenarios for best performance, but the default is True. If average_attn_weights=True (also the default), the attention matrix is averaged across the heads.
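The shape of the second return value depends on these two flags. In this sketch, tgt_len and src_len are our names for the query and key/value sequence lengths, and a batch-first layout is assumed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
E, num_heads, batch, tgt_len, src_len = 8, 2, 3, 4, 5
mha = nn.MultiheadAttention(E, num_heads, batch_first=True)
q = torch.randn(batch, tgt_len, E)
kv = torch.randn(batch, src_len, E)

_, w_avg = mha(q, kv, kv)                                   # need_weights=True, average_attn_weights=True
_, w_per_head = mha(q, kv, kv, average_attn_weights=False)
_, w_none = mha(q, kv, kv, need_weights=False)

print(w_avg.shape)        # torch.Size([3, 4, 5])    (batch, tgt_len, src_len)
print(w_per_head.shape)   # torch.Size([3, 2, 4, 5]) (batch, num_heads, tgt_len, src_len)
print(w_none)             # None
print(torch.allclose(w_avg, w_per_head.mean(dim=1)))
```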
Packed Input Projection Computation
If Q, K, and V have the same embedding dimension, a single concatenated in_proj_weight is used instead of separate {q,k,v}_proj_weight tensors. The unused parameters are set to None.
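You can see the two layouts by inspecting the module's parameters. The sizes below are arbitrary; kdim and vdim are the constructor arguments that make the key/value embedding dimensions differ from embed_dim.

```python
import torch.nn as nn

packed = nn.MultiheadAttention(embed_dim=8, num_heads=2)                    # kdim = vdim = embed_dim
separate = nn.MultiheadAttention(embed_dim=8, num_heads=2, kdim=4, vdim=6)  # different K/V sizes

print(packed.in_proj_weight.shape)   # torch.Size([24, 8]): W_q, W_k, W_v stacked
print(packed.q_proj_weight)          # None
print(separate.in_proj_weight)       # None
print(separate.q_proj_weight.shape, separate.k_proj_weight.shape, separate.v_proj_weight.shape)
# torch.Size([8, 8]) torch.Size([8, 4]) torch.Size([8, 6])
```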
Source and Target Sequences
For pedagogical purposes, we used shape variables like q_len and kv_len. However, the code (and literature) talks about target and source sequence lengths instead (e.g. tgt_len and src_len in F.multi_head_attention_forward), where the queries form the “target” sequence and the key-value pairs form the “source” sequence.
Acknowledgements
Many thanks to Prajjwal Bhargava, Jordan Fix, Kaustubh Gondkar, Jongsoo Park, Blaine Burton Rister for reviewing this post and offering insightful comments.