Authors: Lucas Pasqualin, Michael Gschwind
With the rise of generative models, and the recently added support for custom kernel acceleration during training with Better Transformer, we noticed an opportunity to improve performance of training GPT models.
Specifically, this note explores using custom kernel implementations of SDPA (scaled dot product attention) during training. The custom kernel for SDPA essentially replaces several sequential operations with one functional block, which is where we expect to see performance gains.
For the implementation of GPT, we relied on nanoGPT. In Andrej Karpathy’s (the author) words, nanoGPT is “The simplest, fastest repository for training/finetuning medium-sized GPTs”.
tl;dr
- Using the mem_efficient kernel results in a ~15.5% faster training time per batch, going from a ~154ms/batch baseline to ~130ms/batch.
- Our implementation probably has better overflow protection.
Code
Torch’s sdpa was implemented in place here, replacing the author’s version. For the purposes of this experiment, dropout was set to 0.0 (vs. the original 0.1). Models are also all trained using torch.compile.
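For reference, enabling compilation amounts to wrapping the model once before the training loop; nanoGPT exposes this behind a compile flag in train.py (a minimal sketch, assuming a PyTorch 2.0-era build):

# compile the model once; subsequent forward/backward passes use the optimized graph
model = torch.compile(model)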
A note about custom kernels - Better Transformer supports multiple different kernels optimized for specific use cases, with specific requirements. A kernel picker picks the best kernel for a particular combination of input parameters. If no optimized “custom kernel” for a particular combination of input parameters can be identified, the kernel picker selects a general kernel that can handle all input combinations.
In particular, all custom kernels supported by BT today only support the causal mask when it is specified using the is_causal boolean. When a mask is specified, the general-purpose kernel will be selected because it is too expensive to analyze the contents of a provided mask to determine if it is the causal mask.
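When reproducing the comparison, the kernel choice can also be pinned explicitly. The sketch below is an illustrative addition (not part of the original experiment) that uses the public scaled_dot_product_attention API and the torch.backends.cuda.sdp_kernel context manager from PyTorch 2.0-era builds; the tensor shapes are arbitrary:

import torch
import torch.nn.functional as F

# arbitrary (batch, heads, seq_len, head_dim) tensors for illustration
q = torch.randn(4, 12, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict the kernel picker to the memory-efficient backend only;
# if its constraints are not met, dispatch raises an error instead of
# silently falling back to the generic math kernel.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)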
Original code:
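The attention computation in nanoGPT's CausalSelfAttention.forward looks roughly like the following (a sketch; the exact code may differ between nanoGPT revisions):

# manual attention: matmul, mask, softmax, dropout, matmul
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)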
self.bias, which is passed in as the attention mask in line 55, is defined here:
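In nanoGPT this is a lower-triangular buffer registered in the module's __init__ (a sketch of the repository's definition):

# causal mask to ensure attention is only applied to the left in the input sequence
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))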
Torch implementation using mem_efficient kernel (is_causal=True, attn_mask=None):
Since all constraints for BT mem_efficient kernel are satisfied, the following code runs the optimized kernel.
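A sketch of that call, mirroring the verification snippet later in this note (q, k, v and T come from the surrounding forward(), and the private _scaled_dot_product_attention signature reflects the nightly builds used for this experiment):

# is_causal=True with attn_mask=None satisfies the mem_efficient kernel's constraints
y = torch.nn.functional._scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    dropout_p=0.0,
    need_attn_weights=False,
    is_causal=True
)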
Torch implementation using generic kernel (is_causal=False, original attn_mask):
In this implementation, we keep the original mask but lose the optimized kernel, falling back to the generic implementation.
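A corresponding sketch under the same assumptions, passing the original boolean mask and therefore dispatching to the general-purpose kernel:

# an explicit attn_mask forces the general-purpose kernel
y = torch.nn.functional._scaled_dot_product_attention(
    q, k, v,
    attn_mask=self.bias[:,:,:T,:T] != 0,
    dropout_p=0.0,
    need_attn_weights=False,
    is_causal=False
)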
Results
Training Performance Gain
Training Time (ms) per Batch vs. SDPA Kernel
The above figure demonstrates the performance gained by using PyTorch custom kernels. Here are the exact figures:
- baseline (nanoGPT implementation): ~154ms
- sdpa_math (generic): ~147ms (4.54% faster)
- mem_efficient kernel: ~130ms (15.58% faster)
Overflow Protection
In addition to being faster, torch’s implementation may also be more numerically stable. There is a great explanation here, but essentially the PyTorch implementation scales the query and key matrices before multiplication, which is said to be more stable and to avoid overflow.
Boiled down to the essentials, the differences in implementation are noted here:
nanoGPT:
# notice q is not scaled
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
# att = self.attn_dropout(att) # dropout == 0
y_nanogpt = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
PyTorch:
# notice q _is_ scaled here
embed_size = q.size(-1)
scaling_factor = math.sqrt(math.sqrt(embed_size))
q = q / scaling_factor
att = q @ (k.transpose(-2, -1) / scaling_factor)
# same as above
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
# att = self.attn_dropout(att) # dropout == 0
y_scale_before = att @ v
Mathematically both approaches should be equivalent, but in our experimentation we find that y_nanogpt does not equal y_scale_before.
Additionally, while y_scale_before was verified as being equivalent to calling _scaled_dot_product_attention, y_nanogpt does not match the expected output.
To verify this is an overflow issue, we implemented the provided code and compared the output using torch.allclose, as demonstrated below:
y_sdpa = torch.nn.functional._scaled_dot_product_attention(
    q,
    k,
    v,
    attn_mask=self.bias[:,:,:T,:T] != 0,
    dropout_p=0.0,
    need_attn_weights=False,
    is_causal=False
)
torch.allclose(y_sdpa, y_nanogpt) # False, most likely indicating overflows
torch.allclose(y_sdpa, y_scale_before) # True, as expected
Samples
As an additional sanity check, we trained a model using the original nanoGPT implementation for ~450k iterations, and produced sample inferences using both the original and torch sdpa implementations. No noticeable differences were found between the samples:
- Baseline implementation: Output of `python sample.py` using baseline sdpa - Pastebin.com
- Using torch sdpa: Output of `python sample.py` using `_scaled_dot_product_attention` - Pastebin.com
Environment / Setup
All code was run on a server with (8 x NVIDIA Corporation GA100 [A100 SXM4 80GB])
Environment setup followed these steps:
- Checkout nanoGPT: `git clone https://github.com/karpathy/nanoGPT.git`
- Install Conda: Installation — conda documentation
- Dependencies:
conda install pytorch pytorch-cuda=11.7 -c pytorch-nightly -c nvidia
conda install numpy
conda install datasets
pip install tiktoken
pip install wandb
conda install tqdm