Authors: Lucas Pasqualin, Michael Gschwind
With the rise of generative models, and the recently added support for custom kernel acceleration during training with Better Transformer, we noticed an opportunity to improve the performance of training GPT models.
Specifically, this note explores using custom kernel implementations of SDPA (scaled dot product attention) during training. The custom kernels for SDPA replace several sequential operations with one functional block, which is where we expect to see performance gains.
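To make "one functional block" concrete, here is a minimal numpy sketch (our own reference code, not the post's) of the op sequence a fused SDPA kernel replaces:

```python
import numpy as np

def sdpa_reference(q, k, v):
    """Causal scaled dot product attention as discrete ops:
    matmul -> scale -> mask -> softmax -> matmul."""
    T, d = q.shape[-2], q.shape[-1]
    scores = (q @ np.swapaxes(k, -2, -1)) / np.sqrt(d)
    causal = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(causal, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

A fused kernel produces the same result without materializing the (T, T) intermediates between each step, which is where the savings come from.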
For the implementation of GPT, we relied on nanoGPT. In Andrej Karpathy’s (the author) words, nanoGPT is “The simplest, fastest repository for training/finetuning medium-sized GPTs”.
Using the mem_efficient kernel results in a ~15.5% faster training time per batch, going from a ~154ms/batch baseline to ~130ms/batch.
Additionally, torch's SDPA implementation appears to provide better overflow protection than the baseline (more on this below).
Torch's SDPA was implemented in place here, replacing the author's version. For the purposes of this experiment, dropout was set to 0.0 (vs. the original 0.1). Models are also all trained using
A note about custom kernels - Better Transformer supports multiple different kernels optimized for specific use cases, with specific requirements. A kernel picker picks the best kernel for a particular combination of input parameters. If no optimized “custom kernel” for a particular combination of input parameters can be identified, the kernel picker selects a general kernel that can handle all input combinations.
In particular, all custom kernels supported by BT today only support the causal mask when it is specified using the `is_causal` boolean. When a mask is explicitly specified, the general-purpose kernel is selected instead, because it is too expensive to analyze the contents of a provided mask to determine whether it is the causal mask.
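To see why the picker does not inspect mask contents, consider what detecting "this mask is causal" would require; a hypothetical numpy check (our own illustration, not BT's code):

```python
import numpy as np

# Hypothetical check: confirming a dense (T, T) mask is exactly the causal
# pattern means comparing all T*T entries against a lower-triangular
# template -- an O(T^2) scan that the is_causal flag makes unnecessary.
def looks_causal(attn_mask):
    T = attn_mask.shape[-1]
    return bool(np.array_equal(attn_mask, np.tril(np.ones((T, T), dtype=bool))))
```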
`self.bias`, which is passed in as the attention mask in line 55, is defined here:
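The definition itself is not reproduced in this excerpt; in nanoGPT, `self.bias` is registered (via `register_buffer`) as a lower-triangular buffer of ones, sketched below with a stand-in `block_size` in place of `config.block_size`:

```python
import torch

block_size = 8  # stand-in for config.block_size

# Lower-triangular ones, reshaped to broadcast over (batch, heads, T, T);
# zeros above the diagonal mark future positions to be masked out.
bias = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
```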
Torch implementation using mem_efficient kernel (is_causal=True, attn_mask=None):
Since all constraints for BT mem_efficient kernel are satisfied, the following code runs the optimized kernel.
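A sketch of that call is below. Note the post predates the public API; it used the then-private `torch.nn.functional._scaled_dot_product_attention`, while in PyTorch 2.0+ the public equivalent is `scaled_dot_product_attention`, shown here as an assumption:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes; names are our own, not from the post.
B, n_head, T, head_dim = 2, 4, 16, 32
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_head, T, head_dim)
v = torch.randn(B, n_head, T, head_dim)

# attn_mask=None plus is_causal=True satisfies the custom-kernel constraints,
# leaving the kernel picker free to dispatch to mem_efficient on CUDA inputs.
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
```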
Torch implementation using generic kernel (is_causal=False, original attn_mask)
In this implementation, we keep the original mask but lose the optimized kernel, falling back to the generic implementation.
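A sketch of the fallback path, again using the current public API as an assumption (the post used the private `_scaled_dot_product_attention`):

```python
import torch
import torch.nn.functional as F

B, n_head, T, head_dim = 2, 4, 16, 32
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_head, T, head_dim)
v = torch.randn(B, n_head, T, head_dim)

# Passing an explicit boolean mask (True = attend) forces the generic kernel,
# even though this particular mask is exactly the causal pattern.
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
y_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask,
                                          dropout_p=0.0, is_causal=False)
```

Both paths compute the same attention; only the kernel selection differs.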
Training Performance Gain
Training Time (ms) per Batch vs. SDPA Kernel
The above figure demonstrates the performance gained by using PyTorch custom kernels. Here are the exact figures:
- baseline (nanoGPT implementation): ~154ms
- sdpa_math (generic kernel): ~147ms (4.54% faster)
- mem_efficient kernel: ~130ms (15.58% faster)
In addition to being faster, torch's implementation may also be more numerically stable. There is a great explanation here, but essentially the PyTorch implementation scales the Query and Key matrices before multiplication, which is said to be more stable and to avoid overflow.
Boiled down to the essentials, the differences in implementation are noted here:
```python
# notice q is not scaled
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
# att = self.attn_dropout(att)  # dropout == 0
y_nanogpt = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
```
```python
# notice q _is_ scaled here
embed_size = q.size(-1)
scaling_factor = math.sqrt(math.sqrt(embed_size))
q = q / scaling_factor
att = q @ (k.transpose(-2, -1) / scaling_factor)
# same as above
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
# att = self.attn_dropout(att)  # dropout == 0
y_scale_before = att @ v
```
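The practical consequence of the ordering can be demonstrated in float16 with numpy; the values below are hypothetical, chosen specifically to overflow fp16:

```python
import numpy as np

d = 64
q = np.full((1, d), 50.0, dtype=np.float16)
k = np.full((1, d), 50.0, dtype=np.float16)

# nanoGPT order: multiply first, scale after. 50 * 50 * 64 = 160000 exceeds
# the fp16 max (~65504), so the intermediate product is already inf.
att_after = (q @ k.T) * (1.0 / np.sqrt(d))

# torch order: divide both operands by embed_size ** 0.25 first; the product
# is ~20000 and finite, and mathematically identical to the line above.
s = np.float16(d ** 0.25)
att_before = (q / s) @ (k.T / s)
```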
Mathematically both approaches should be equivalent, but in our experimentation we found that `y_nanogpt` does not equal `y_scale_before`. While `y_scale_before` was verified as being equivalent to calling torch's SDPA directly, `y_nanogpt` does not match that expected output.
To verify that this is an overflow issue, we implemented the provided code and compared the outputs using `torch.allclose`, as demonstrated below:
```python
y_sdpa = torch.nn.functional._scaled_dot_product_attention(
    q, k, v,
    attn_mask=self.bias[:,:,:T,:T] != 0,
    dropout_p=0.0,
    need_attn_weights=False,
    is_causal=False,
)

torch.allclose(y_sdpa, y_nanogpt)       # False, most likely indicating overflows
torch.allclose(y_sdpa, y_scale_before)  # True, as expected
```
As an additional sanity check, we trained a model using the original nanoGPT implementation for ~450k iterations and produced sample inferences using both the original and torch SDPA implementations. No noticeable differences were found between the two samples:
- Baseline implementation - Output of `python sample.py` using baseline sdpa - Pastebin.com
- Using torch sdpa - Output of `python sample.py` using `_scaled_dot_product_attention` - Pastebin.com
Environment / Setup
All code was run on a server with 8 x NVIDIA A100 SXM4 80GB GPUs (GA100).
Environment followed these steps:
- Clone the repo: `git clone https://github.com/karpathy/nanoGPT.git`
- Install Conda - Installation — conda 23.1.0.post16+5a9deaf5a documentation
- Install dependencies:

```shell
conda install pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
conda install numpy
conda install datasets
pip install tiktoken
pip install wandb
conda install tqdm
```