Understanding Multi-Head Attention for ML Framework Developers

This post is the outcome of my frustrations learning about MHA from the perspective of an ML framework developer. I am sharing this here in case others find it helpful.

Introduction

Multi-Head Attention (MHA) is an operator that was initially introduced as part of the Transformer architecture in the influential paper "Attention is All You Need" by Vaswani et al. Since its introduction, MHA has been incorporated into various Machine Learning (ML) frameworks, including PyTorch. There is a wealth of resources aimed at ML practitioners that give a high-level overview of MHA, but they often do not delve into the specifics of its implementation. These details are quite helpful for ML framework developers who have to provide implementations of MHA. Our post aims to fill this gap by giving ML framework developers a crash course that emphasizes exactly how the tensors are transformed.

First, we will look at how Scaled Dot Product Attention (SDPA), a component of MHA, operates on its input tensors. Then, we will provide some intuition on what SDPA is calculating (this section is optional). Next, we will look at the whole MHA operator. Finally, we will explore the PyTorch operators F.scaled_dot_product_attention, F.multi_head_attention_forward, and torch.nn.MultiheadAttention.

Scaled Dot Product Attention

Scaled Dot Product Attention (SDPA) is a component of the Multi-Head Attention operator. We will first present the definition of the operator, then we will build intuition on what this operator is doing using the soft database lookup interpretation. If you don’t care for the intuition, feel free to skip it.

Definition

Scaled dot product attention is defined as

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$

where Q,K,V stand for queries, keys, and values respectively, and d is the dimension of the queries/keys (shared).
We explain why these inputs are called queries, keys, and values in the Soft DB Lookup Interpretation section.


Figure 1: Reference SDPA Data Flow Graph annotated with output sizes. We omit the batch dimension(s) for simplicity.

You will notice that Figure 1 contains an additional "mask" node that is not in the formula. This node lets us mask out entries that are either irrelevant or should not be used (e.g. when training language models, the model should not look at the word it is supposed to be predicting). The mask typically adds -inf to the scores that should be ignored, and these become 0s after the softmax.
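To make the -inf trick concrete, here is a minimal PyTorch sketch that was not part of the original post. It builds one common kind of mask, a causal mask in which query i may only look at keys j ≤ i, and adds it to the scaled scores before the softmax. It assumes q_len == kv_len and omits the batch dimension, as Figure 1 does.

```python
import math

import torch

torch.manual_seed(0)
q_len, kv_len, qk_dim = 4, 4, 8
Q = torch.randn(q_len, qk_dim)
K = torch.randn(kv_len, qk_dim)

# Scaled scores, shape [q_len x kv_len]
scores = Q @ K.T / math.sqrt(qk_dim)

# Additive causal mask: 0 where query i may use key j (j <= i), -inf above the diagonal
mask = torch.triu(torch.full((q_len, kv_len), float("-inf")), diagonal=1)

# exp(-inf) == 0, so masked entries get exactly zero probability after the softmax
probs = torch.softmax(scores + mask, dim=-1)
print(probs)
```

Every row still sums to 1. The probability mass is simply redistributed over the keys that remain visible.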

Soft Database Lookup Interpretation


The SDPA operator can be thought of as performing q_len soft database lookups on a key-value store of kv_len key-value pairs.

The queries, keys, and values are all vectors.

First let us consider what a hard (regular) DB lookup would be. For each query Q[i]

  1. Find the most relevant key K[j]
  2. Return its corresponding value V[j]

Note the following:

  1. Queries and keys should be comparable - in practice this means they have the same dimension qk_dim
  2. There are as many keys as values (both kv_len) - but there may be a different number of queries (q_len)

In the soft lookup case, instead of finding the single most relevant key K[j] and returning its corresponding value, we calculate the relevance of each key to the query and return a weighted average of the values based on the relevance scores.
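The contrast between the two kinds of lookup can be sketched in plain Python. The helper names hard_lookup and soft_lookup are hypothetical and exist only for illustration. This toy version uses the plain dot product as its relevance score and softmax as its weighting, both of which are explained below.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def hard_lookup(q, keys, values):
    # Return the value whose key is most similar to the query
    best = max(range(len(keys)), key=lambda j: dot(q, keys[j]))
    return values[best]


def soft_lookup(q, keys, values):
    # Score every key, turn the scores into weights with softmax,
    # then return the weighted average of all values
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    v_dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(v_dim)]


keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0], [20.0], [30.0]]
q = [2.0, 0.5]

hard = hard_lookup(q, keys, values)  # [10.0]: only the best-matching value
soft = soft_lookup(q, keys, values)  # mostly 10.0, pulled slightly toward 20.0 and 30.0
print(hard, soft)
```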

To see why the scaled_dot_product_attention operator calculates a soft lookup, we will first need to understand 2 tools used in SDPA: dot-product similarity and softmax. Then, we will present a node-by-node intuition on how SDPA calculates this soft database lookup.

Scaled Dot Product Similarity

A similarity measure calculates the similarity between vectors q and k. In the context of soft database lookup, we need a similarity measure when comparing a query against keys to decide how much weight to give each key. It is kind of the opposite of a distance metric (i.e. similarity is greater for similar vectors and smaller for dissimilar vectors). Some similarity measures include dot product and cosine.

Why do dot product and cosine work as similarity measures? Let’s look at cosine first because it is easier to understand:

Cosine

The cosine similarity between u and v is simply the cosine of the angle θ between them. We can calculate cosine similarity using the formula:

$$\cos(\theta) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}$$

Why is cosine a good similarity measure? When two vectors are similar, they point in roughly the same direction. The angle between them is small, so its cosine is large (close to 1), and vice versa (see graph). However, notice that cosine ignores the magnitudes of the vectors and focuses only on direction.
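A direct transcription of the cosine formula into plain Python shows both properties. Vectors in the same direction score 1 regardless of length, orthogonal vectors score 0, and opposite vectors score -1. The function name cosine_similarity is illustrative only.

```python
import math


def cosine_similarity(u, v):
    # cos(theta) = (u . v) / (|u| |v|)
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


same = cosine_similarity([1.0, 0.0], [2.0, 0.0])         # 1.0: same direction
same_long = cosine_similarity([1.0, 0.0], [100.0, 0.0])  # 1.0: magnitude is ignored
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 3.0])   # 0.0
opposite = cosine_similarity([1.0, 0.0], [-5.0, 0.0])    # -1.0
print(same, same_long, orthogonal, opposite)
```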

Dot product

When the magnitudes of the vectors in question are meaningful, we would want a similarity measure that takes this into account.

Rearranging the cosine equation, we get

$$u \cdot v = \lVert u \rVert \, \lVert v \rVert \cos(\theta)$$

From the above formula, it's clear that for fixed magnitudes the dot product is directly proportional to the cosine. When the magnitudes of the input vectors are meaningful, it's better to just take the dot product, which also has the added benefit of being computationally cheaper. The dot product is a good similarity measure for roughly the same reason cosine is, and it additionally accounts for magnitude.
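The difference can be shown with two keys that point in exactly the same direction as the query but have different lengths. This is a small plain-Python sketch with made-up vectors. Cosine cannot tell the two keys apart, while the dot product prefers the longer one, exactly as u · v = |u| |v| cos(θ) predicts.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def norm(u):
    return math.sqrt(dot(u, u))


def cosine(u, v):
    return dot(u, v) / (norm(u) * norm(v))


query = [3.0, 4.0]
key_short = [3.0, 4.0]  # same direction as the query, length 5
key_long = [6.0, 8.0]   # same direction as the query, length 10

# Cosine sees no difference: both keys point exactly along the query
print(cosine(query, key_short), cosine(query, key_long))  # 1.0 1.0

# The dot product also rewards magnitude
print(dot(query, key_short), dot(query, key_long))  # 25.0 50.0
```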

Scaled Dot product

Dot products can grow very large for high-dimensional vectors, which can cause numerical issues. If the components of u and v are independent with mean 0 and variance 1, then u · v has mean 0 and variance d. Dividing by √d therefore brings the variance back to 1. Vaswani et al. use this scaled dot product similarity measure:

$$\mathrm{sim}(u, v) = \frac{u \cdot v}{\sqrt{d}}$$

where d is the dimension of u and v.
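The variance argument can be checked numerically with the standard library alone. The sketch assumes that every component is drawn independently from a standard normal distribution. That assumption matches the argument above but is not a guarantee about real queries and keys. The unscaled variance should grow roughly like d, while the scaled variance stays near 1.

```python
import math
import random
import statistics

random.seed(0)


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def dot_variance(d, scaled, trials=1000):
    # Sample variance of u.v (or u.v / sqrt(d)) over random vectors with
    # i.i.d. standard-normal components
    samples = []
    for _ in range(trials):
        u = [random.gauss(0.0, 1.0) for _ in range(d)]
        v = [random.gauss(0.0, 1.0) for _ in range(d)]
        s = dot(u, v)
        samples.append(s / math.sqrt(d) if scaled else s)
    return statistics.variance(samples)


results = {d: (dot_variance(d, scaled=False), dot_variance(d, scaled=True)) for d in (4, 64, 256)}
for d, (raw, scaled) in results.items():
    print(f"d={d:3d}  var(u.v)={raw:7.1f}  var(u.v/sqrt(d))={scaled:.2f}")
```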

Softmax

The softmax function is best illustrated by the following diagram.

Figure 2: How softmax transforms a vector. Source

i.e. softmax(v) := exp(v)/sum(exp(v)).

Softmax can be thought of as a soft argmax function. As we will see later - this soft argmax intuition will be useful in understanding SDPA as a soft database lookup.

Softmax is great because it

  1. Respects the order of the input vector (notice the largest and smallest values occur in indices 2,3 in both the input to and output of softmax above)
  2. Output values will always be positive and sum to 1 (probability distribution)

But regular normalization also satisfies the above properties. So, why do we prefer softmax to a simpler normalization technique? That is out-of-scope for this post, but see this discussion for more details.
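The two properties listed above can be checked with a minimal plain-Python softmax. Subtracting the maximum before exponentiating is a standard numerical-stability trick and is not part of the definition. It does not change the result, because the common factor exp(-max) cancels between numerator and denominator.

```python
import math


def softmax(v):
    # softmax(v)_i = exp(v_i) / sum_j exp(v_j), computed with the max subtracted
    # so that exp() cannot overflow for large inputs
    m = max(v)
    exps = [math.exp(x - m) for x in v]
    total = sum(exps)
    return [e / total for e in exps]


v = [1.0, 3.0, -2.0, 0.5]
p = softmax(v)
print([round(x, 3) for x in p])

# Without the max subtraction, exp(1000.0) would overflow
print(softmax([1000.0, 1000.0]))  # [0.5, 0.5]
```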

Soft Lookup Intuition: Node-by-Node

In the following table we build up the intuition for what each node/operator in the graph is doing culminating in an explanation for how the output is the soft database lookup.

| Node | Formula | Shape | Intuition + Notes |
| --- | --- | --- | --- |
| Q | Q | [q_len x qk_dim] | There are q_len (independent) queries in Q. Each query Q[i] is a qk_dim-dimensional vector. |
| K | K | [kv_len x qk_dim] | There are kv_len keys in K. Each key K[j] is a qk_dim-dimensional vector. Notice that kv_len can be different from q_len. |
| V | V | [kv_len x v_dim] | There are kv_len values in V. Each value V[j] is a v_dim-dimensional vector. Notice that v_dim can be different from qk_dim. |
| A | QKᵀ | [q_len x kv_len] | A[i][j] is the dot product similarity between Q[i] and K[j]. |
| A_scaled | QKᵀ/√d | [q_len x kv_len] | As discussed, we scale these dot products to prevent numerical issues when qk_dim is large. |
| A_probs | softmax(QKᵀ/√d) | [q_len x kv_len] | The softmax, applied to each row, calculates a soft argmax over the values based on how relevant their keys are to the query. In a hard DB lookup, we would instead take an argmax to find the best key for every query, so that A_probs[i][j] = 1 if key j is the best match for query i and A_probs[i][j] = 0 otherwise. Softmax gives a soft version of this, where 0 < A_probs[i][j] < 1 is the weight on value V[j] when calculating the final lookup for query i. |
| output | softmax(QKᵀ/√d)V | [q_len x v_dim] | output[i] is a v_dim-dimensional vector: the average of all values weighted by their relevance/similarity to Q[i]. This is the soft lookup for query i. |
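The table translates almost line by line into a PyTorch reference implementation. reference_sdpa is a hypothetical helper written for this sketch. Its output is compared with F.scaled_dot_product_attention. The PyTorch docs describe that function's inputs with leading batch dimensions, (N, ..., L, E), so the sketch adds a batch dimension of size 1 and removes it afterwards. The shapes deliberately use q_len != kv_len and v_dim != qk_dim.

```python
import math

import torch
import torch.nn.functional as F


def reference_sdpa(Q, K, V, mask=None):
    """Node-by-node SDPA. Q: [q_len x qk_dim], K: [kv_len x qk_dim], V: [kv_len x v_dim]."""
    d = Q.shape[-1]
    A = Q @ K.transpose(-2, -1)                # [q_len x kv_len]: A[i][j] = Q[i] . K[j]
    A_scaled = A / math.sqrt(d)                # keep the scores' variance around 1
    if mask is not None:
        A_scaled = A_scaled + mask             # additive mask: 0 keeps an entry, -inf drops it
    A_probs = torch.softmax(A_scaled, dim=-1)  # each row is a soft argmax over the keys
    return A_probs @ V                         # [q_len x v_dim]: weighted average of the values


torch.manual_seed(0)
q_len, kv_len, qk_dim, v_dim = 3, 5, 8, 4
Q = torch.randn(q_len, qk_dim)
K = torch.randn(kv_len, qk_dim)
V = torch.randn(kv_len, v_dim)

out = reference_sdpa(Q, K, V)
expected = F.scaled_dot_product_attention(Q[None], K[None], V[None])[0]
print(out.shape)  # torch.Size([3, 4])
print(torch.allclose(out, expected, atol=1e-5))
```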

Multi-Head Attention

Multi-head attention performs multiple attention operations on different projections (“heads”) of the input data (keys, queries, values) in parallel. Multi-head attention requires that keys, queries, values be split into num_heads heads each - this is relaxed in variants like Grouped Query Attention (used in Llama), where there are multiple query heads for each key/value head. Each attention head can capture different features/relationships between input tokens.

As before, we will assume unbatched inputs. Multi-head attention roughly does the following. Steps 1, 3, and 5 are the compute operations; steps 2 and 4 only move data around.

  1. Linearly projects the queries, keys, values independently
  2. Splits the projected values along the embedding dimension into num_heads “heads”
  3. Computes scaled dot product attention for each head (heads act as (additional) batch dimension)
  4. Combines the heads back from the SDPA output.
  5. Linearly projects the output
  6. Return linearly projected output
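The steps above can be sketched from scratch in PyTorch and checked against torch.nn.MultiheadAttention. manual_mha is a hypothetical helper. The sketch assumes the default constructor arguments (bias=True, dropout=0.0, kdim = vdim = embed_dim). Under those defaults the module stores the three input projections stacked in a single in_proj_weight, in q, k, v order. Unbatched inputs are used, as in the text.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 8, 2
q_len, kv_len = 3, 5
mha = nn.MultiheadAttention(embed_dim, num_heads).eval()

# Unbatched inputs: [seq_len x embed_dim]
query = torch.randn(q_len, embed_dim)
key = torch.randn(kv_len, embed_dim)
value = torch.randn(kv_len, embed_dim)


def manual_mha(query, key, value, mha):
    h, hd, e = mha.num_heads, mha.head_dim, mha.embed_dim
    # 1. Project Q, K, V independently (W_q, W_k, W_v are stacked in in_proj_weight)
    W_q, W_k, W_v = mha.in_proj_weight.chunk(3)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3)
    q, k, v = query @ W_q.T + b_q, key @ W_k.T + b_k, value @ W_v.T + b_v
    # 2. Split the embedding dim into heads: [len x e] -> [h x len x hd]
    q, k, v = (t.view(t.shape[0], h, hd).transpose(0, 1) for t in (q, k, v))
    # 3. SDPA per head: heads act as a batch dim, and the scale uses head_dim
    probs = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(hd), dim=-1)
    heads = probs @ v                                 # [h x q_len x hd]
    # 4. Merge the heads back: [h x q_len x hd] -> [q_len x e]
    merged = heads.transpose(0, 1).reshape(-1, e)
    # 5. Output projection
    return merged @ mha.out_proj.weight.T + mha.out_proj.bias


with torch.no_grad():
    ours = manual_mha(query, key, value, mha)
    theirs, _ = mha(query, key, value)

print(ours.shape)  # torch.Size([3, 8])
print(torch.allclose(ours, theirs, atol=1e-5))
```

Note that in step 3 the "d" of SDPA is the per-head dimension, embed_dim // num_heads, not embed_dim.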

Of course, to cajole the tensors into the right format, MHA implementations might use a combination of chunks, concats, permutes, reshapes, etc. The PyTorch MHA operator handles many different tensor formats, each requiring its own series of data movement operations. We will cover some of this in the PyTorch section.

Definition

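The "Attention is All You Need" paper gives the definition of Multi-Head Attention as:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O, \quad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$

where h is num_heads and W_i^Q, W_i^K, W_i^V, and W^O are learned projection matrices.
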
Unlike for SDPA, the closed form formula for MHA is rather confusing. It is easier to look at the following data flow graph to understand what is going on.

Figure 3: Reference MHA Data Flow Graph annotated with output sizes. We omit the batch dimension(s) for simplicity.

Pytorch Implementations

The diagrams above show a naive reference implementation. Much work has gone into far more efficient implementations of these operators in PyTorch and other frameworks.

Scaled Dot Product Attention

The SDPA operator is fairly straightforward. You can read more about the scaled_dot_product_attention implementation in the official docs. When available, a more efficient implementation like FlashAttention is used. One thing to note is that the "mask" can also be a boolean tensor, where True indicates that a value should take part in attention. This may be the opposite in other frameworks (like JAX).
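The boolean and additive forms of the mask can be compared directly. The sketch below builds a causal mask both ways and also uses the is_causal flag. All three calls to F.scaled_dot_product_attention should agree. The tensor sizes are arbitrary, and a batch dimension of size 1 is added to match the documented (N, ..., L, E) layout.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L, E = 4, 8
q, k, v = (torch.randn(1, L, E) for _ in range(3))

# Boolean mask for F.scaled_dot_product_attention: True = "this entry takes part in attention"
bool_mask = torch.ones(L, L, dtype=torch.bool).tril()

# The equivalent additive mask: 0.0 where allowed, -inf where masked out
float_mask = torch.zeros(L, L).masked_fill(~bool_mask, float("-inf"))

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Even within PyTorch, the convention is not uniform. torch.nn.MultiheadAttention interprets a boolean attn_mask or key_padding_mask the other way around: there, True marks positions that are not allowed to attend.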

Multi-Head Attention

There are two different ways to use Multi-Head Attention in PyTorch: torch.nn.MultiheadAttention and F.multi_head_attention_forward. The typical way is torch.nn.MultiheadAttention. Its forward method is functionally equivalent to F.multi_head_attention_forward. In fact, whenever it does not take its fast path, it simply calls F.multi_head_attention_forward. If you want to go digging, F.multi_head_attention_forward gives better insight into what the MHA operator actually does.
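One way to see the equivalence is to call both entry points with the module's own parameters and compare the results. The sketch assumes default constructor arguments and eval mode, so dropout is inactive. With the default batch_first=False the module does not take its fast path, so both outputs come from the same function. The positional arguments mirror the call in the snippet below.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, num_heads, seq_len, batch = 8, 2, 5, 3
mha = nn.MultiheadAttention(embed_dim, num_heads).eval()
x = torch.randn(seq_len, batch, embed_dim)  # default (seq_len, batch, embed_dim) layout

with torch.no_grad():
    module_out, module_w = mha(x, x, x)
    func_out, func_w = F.multi_head_attention_forward(
        x, x, x, embed_dim, num_heads,
        mha.in_proj_weight, mha.in_proj_bias,
        mha.bias_k, mha.bias_v, mha.add_zero_attn,
        mha.dropout, mha.out_proj.weight, mha.out_proj.bias,
        training=mha.training,
    )

print(module_out.shape, module_w.shape)  # torch.Size([5, 3, 8]) torch.Size([3, 5, 5])
```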

torch.nn.MultiheadAttention

    from typing import Optional, Tuple

    import torch.nn.functional as F
    from torch import Tensor

    # Condensed excerpt of MultiheadAttention.forward (torch/nn/modules/activation.py).
    # It is a method of the class, not a standalone script; "..." marks elided code.
    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            key_padding_mask: Optional[Tensor] = None,
            need_weights: bool = True,
            attn_mask: Optional[Tensor] = None,
            average_attn_weights: bool = True,
            is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]:

        why_not_fast_path = ''
        is_batched = query.dim() == 3
        # ... checks that set why_not_fast_path to a reason string if the fast path can't be used
        if not why_not_fast_path:
            # ... set up fast-path inference and return its result
            ...

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
                                f"The fast path was not hit because {why_not_fast_path}")

        # F.multi_head_attention_forward expects (seq_len, batch, embed_dim), so swap the first two dims
        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (x.transpose(1, 0) for x in (query, key, value))

        if not self._qkv_same_embed_dim:
            # kdim/vdim differ from embed_dim: pass separate q/k/v projection weights
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal)
        else:
            # Same embedding dims: one packed in_proj_weight holds the q, k and v projections
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal)
        # Undo the transpose so the output matches the caller's batch_first layout
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights

Snippet: MultiheadAttention.forward, condensed

This module’s forward method is functionally identical to F.multi_head_attention_forward, but has some optional pre-processing and routing to a faster implementation (fast_path). You can find the full list of arguments to the constructor and the forward method in the docs. However, here are some peculiarities worth calling out for framework engineers:

Batches

So far we have assumed no batch dimension, but MHA can handle batched input. By default, batched tensors should have shape (seq_len, batch_size, embed_dim), so the batch dimension is the middle one. You can set batch_first=True in torch.nn.MultiheadAttention to make it the outermost (first) dimension instead. We found no convincing answer to why the default is False. Our best guess is that this is due to historical reasons.
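The two layouts compute the same thing. This sketch copies one module's weights into a second module built with batch_first=True and feeds it the transposed input. The sizes are arbitrary. With batch_first=True, eval mode, and autograd disabled, the module may route to its fast path, so the comparison uses a small tolerance instead of exact equality.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads, seq_len, batch = 8, 2, 5, 3
seq_first = nn.MultiheadAttention(embed_dim, num_heads).eval()                      # (seq_len, batch, embed_dim)
batch_first = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()  # (batch, seq_len, embed_dim)
batch_first.load_state_dict(seq_first.state_dict())  # same weights; only the expected layout differs

x = torch.randn(batch, seq_len, embed_dim)
with torch.no_grad():
    out_bf, _ = batch_first(x, x, x)
    x_sf = x.transpose(0, 1)  # (seq_len, batch, embed_dim)
    out_sf, _ = seq_first(x_sf, x_sf, x_sf)

print(out_bf.shape, out_sf.shape)  # torch.Size([3, 5, 8]) torch.Size([5, 3, 8])
```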

Return Type, need_weights, average_attn_weights

This operator returns a tuple whose first value is the result of MHA described above. If need_weights=True, the second returned value is the attention matrix computed inside the scaled_dot_product_attention step (i.e. softmax(QKᵀ / √d)). Otherwise, the second value is None.

Why does this option exist? The attention matrix can help with analyzing and interpreting network weights. However, returning it requires fully materializing the softmax node, which fast implementations like FlashAttention skip. Turning this on therefore forces a much slower computation of SDPA. This option can likely be turned off in most production scenarios for best performance, but the default is True. If average_attn_weights=True (also the default), the attention matrix is averaged across the heads.
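The effect of the two flags on the second return value can be seen from the shapes alone. The sizes below are arbitrary, and the sketch uses different target and source lengths so the two are easy to tell apart. The averaged weights should equal the mean of the per-head weights over the head dimension.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads, tgt_len, src_len, batch = 8, 2, 3, 5, 4
mha = nn.MultiheadAttention(embed_dim, num_heads).eval()
query = torch.randn(tgt_len, batch, embed_dim)
key = value = torch.randn(src_len, batch, embed_dim)

with torch.no_grad():
    _, w_avg = mha(query, key, value)  # defaults: need_weights=True, average_attn_weights=True
    _, w_per_head = mha(query, key, value, average_attn_weights=False)
    _, w_none = mha(query, key, value, need_weights=False)

print(w_avg.shape)       # torch.Size([4, 3, 5]): (batch, tgt_len, src_len)
print(w_per_head.shape)  # torch.Size([4, 2, 3, 5]): (batch, num_heads, tgt_len, src_len)
print(w_none)            # None
```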

Packed Input Projection Computation

If the embedding dimensions of Q, K, and V are all the same, the module uses one concatenated in_proj_weight instead of separate {q,k,v}_proj_weight parameters. The unused parameters are set to None.
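The two parameter layouts can be inspected directly. The sketch reads _qkv_same_embed_dim, which is a private attribute (note the leading underscore) and may change between PyTorch versions. The kdim/vdim values are arbitrary.

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2

# kdim == vdim == embed_dim (the default): one packed weight for Q, K and V
packed = nn.MultiheadAttention(embed_dim, num_heads)
print(packed._qkv_same_embed_dim)   # True
print(packed.in_proj_weight.shape)  # torch.Size([24, 8]): (3 * embed_dim, embed_dim)
print(packed.q_proj_weight)         # None

# Different key/value dims: three separate projection weights, no packed weight
separate = nn.MultiheadAttention(embed_dim, num_heads, kdim=6, vdim=4)
print(separate._qkv_same_embed_dim)  # False
print(separate.in_proj_weight)       # None
print(separate.q_proj_weight.shape, separate.k_proj_weight.shape, separate.v_proj_weight.shape)
# torch.Size([8, 8]) torch.Size([8, 6]) torch.Size([8, 4])
```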

Source and Target Sequences

For pedagogical purposes, we used shape variables like q_len and kv_len. The code and the literature instead talk about target and source sequence lengths (tgt_len and src_len in F.multi_head_attention_forward, L and S in the docs). The queries are the "target" sequence and the key-value pairs are the "source" sequence.

Acknowledgements

Many thanks to Prajjwal Bhargava, Jordan Fix, Kaustubh Gondkar, Jongsoo Park, Blaine Burton Rister for reviewing this post and offering insightful comments.


batch_first is False by default because LSTMs and Linear RNNs used to be the norm in sequence-to-sequence modeling. When Transformers came about, we tried to make MHA's output format match that of LSTMs/RNNs.

nn.LSTM/nn.RNN defaulted to (seq, batch, size) because cuDNN provided LSTM and RNN kernels that used that layout. It was easier to write high-performance LSTM/RNN kernels when the outer dimension was sequence length, so cuDNN defaulted to that. PyTorch in turn defaulted to that layout to leverage these kernels.

So, in summary, no good reason why nn.MHA defaults to batch_first=False other than that it tried to match nn.LSTM which tried to use CuDNN optimally.
