Making Transformer inference faster on GPUs

TL;DR - if you’re doing GPU inference with models using Transformers in PyTorch, and you want a quick way to improve efficiency, you could consider calling transformer = NVFasterTransformer(old_transformer) or similar. You can expect large improvements (~4x) in small-batch, variable-sequence-length cases, and smaller improvements (~1.4x) in large-batch, large-sequence-length cases.

Transformer Inference on GPUs

Generally speaking, running Transformers efficiently on GPUs is fairly simple. Most of the computation is in large dense matrix multiplications, which run at high efficiency on GPUs, and we can trivially run in float16 to take advantage of tensor cores. So at large batch sizes things work great, reaching >70% of peak cuBLAS FLOPs.
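
As a rough sanity check on the claim that dense matmuls dominate, here is a back-of-the-envelope FLOP count for one BERT-style encoder layer. It is a sketch under stated assumptions: hidden size d_model, an FFN of width 4 * d_model, a multiply-add counted as 2 FLOPs, and softmax, layernorm and other elementwise ops ignored. The function name is illustrative only.

```python
def encoder_layer_flops(seq_len: int, d_model: int, ffn_mult: int = 4) -> dict:
    """Approximate forward FLOPs for one BERT-style encoder layer on one sequence.

    Assumes QKV projection, output projection, a two-layer FFN of width
    ffn_mult * d_model, and the two attention bmms. A multiply-add counts as
    2 FLOPs; elementwise ops (softmax, layernorm, GELU) are ignored.
    """
    t, d = seq_len, d_model
    # Dense GEMMs per token: QKV (3 d^2) + output proj (d^2) + FFN up/down (2 * ffn_mult * d^2).
    dense = 2 * t * (3 * d * d + d * d + 2 * ffn_mult * d * d)
    # Attention bmms: Q @ K^T and probs @ V, each t * t * d multiply-adds.
    attention = 2 * 2 * t * t * d
    return {
        "dense": dense,
        "attention": attention,
        "dense_fraction": dense / (dense + attention),
    }


r = encoder_layer_flops(seq_len=128, d_model=768)
print(f"{r['dense_fraction']:.3f}")  # 0.973
```

With these assumptions the attention bmms are only about seq_len / (6 * d_model) of the dense work, so for BERT-base-like sizes roughly 97% of the FLOPs land in the GEMMs that cuBLAS handles well.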

Things get a little more challenging in the inference setting, where we face two different problems:

  1. Running large batch sizes is infeasible due to latency constraints, requiring small or medium-sized batches.
  2. With variable input lengths, we waste a significant amount of computation on the padded entries.

For the first problem, a naive GPU Transformer implementation becomes kernel-launch-latency bound at small batch sizes, and a typical trace shows lots of gaps in the GPU stream. One fix is kernel fusion: merging many small kernels together so that we pay the ~10 µs kernel launch latency far fewer times. There are a few dozen fully-fused CUDA Transformer implementations lying around in OSS, such as FasterTransformer, effective_transformer, TurboTransformers, DeepSpeed, etc.
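
A simple model shows why small batches become launch-bound: if every kernel pays a fixed launch latency and each kernel does very little work, the fixed cost dominates. Only the ~10 µs launch latency comes from the text; the layer count, kernels per layer, and per-layer compute time below are hypothetical placeholders, and the model pessimistically assumes launches are never hidden behind GPU work.

```python
def launch_overhead_fraction(n_layers: int, kernels_per_layer: int,
                             compute_us_per_layer: float,
                             launch_us: float = 10.0) -> float:
    """Fraction of wall time lost to kernel launches in a serialized model.

    Assumes each launch adds a fixed latency that is not overlapped with GPU
    work (the 'gaps in the stream' regime). Real CUDA streams can partially
    overlap launches with execution, so this is a pessimistic estimate.
    """
    launch_total = n_layers * kernels_per_layer * launch_us
    compute_total = n_layers * compute_us_per_layer
    return launch_total / (launch_total + compute_total)


# Hypothetical 12-layer encoder at small batch: fusion cuts the kernel count
# per layer while the useful GPU work per layer stays the same.
unfused = launch_overhead_fraction(12, kernels_per_layer=20, compute_us_per_layer=100.0)
fused = launch_overhead_fraction(12, kernels_per_layer=4, compute_us_per_layer=100.0)
print(f"unfused: {unfused:.2f}, fused: {fused:.2f}")  # unfused: 0.67, fused: 0.29
```

The useful compute is identical in both cases; fusion only removes fixed per-launch cost, which is why it matters most when per-kernel work is small.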

These allow us to become GPU-bound rather than launch-bound even at low batch sizes, which is nice. These libraries also implement other minor tweaks, such as pre-transposing weights to help cuBLAS, and so on.

The second problem is a little more complex and needs some elaboration. The typical approach for handling variable-length inputs (e.g. a batch of B sequences, where sequence b has length T_b) is to stack them into a tensor of size (B, T_max), adding padding where necessary.
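
The standard padded layout can be sketched with plain Python lists standing in for tensors. The helper name and the choice of pad_id=0 are assumptions for illustration; real code would also carry the mask into attention so padded positions are ignored.

```python
from typing import List, Tuple


def pad_batch(seqs: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[List[bool]]]:
    """Stack B variable-length sequences into a (B, T_max) grid plus a mask.

    mask[b][t] is True for real tokens and False for padding.
    """
    t_max = max(len(s) for s in seqs)
    padded = [s + [pad_id] * (t_max - len(s)) for s in seqs]
    mask = [[True] * len(s) + [False] * (t_max - len(s)) for s in seqs]
    return padded, mask


padded, mask = pad_batch([[5, 6, 7], [8]])
# padded == [[5, 6, 7], [8, 0, 0]]
# mask   == [[True, True, True], [True, False, False]]
```

Every downstream op now processes B * T_max positions, including the two padding slots in the second row.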

When there is a large gap between T_avg and T_max (e.g. T_max = 256, T_avg = 64), we’d expect a significant amount of wasted computation (~4x in that case, since we process B * T_max positions but only B * T_avg of them are real). Thus we have a tradeoff: larger batches are good for efficiency, but they also increase the amount of wasted computation (in expectation), because the longest sequence in a batch tends to grow with the batch size. There are some approaches to reconcile this involving bucketing and batching within a sequence-length bucket, but these all involve some unfortunate tradeoffs and introduce non-trivial complexity.
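
The ~4x figure is just the ratio of padded to real positions for per-token ops such as the dense GEMMs. The batch of lengths below is hypothetical, chosen so that T_max = 256 and T_avg = 64.

```python
from typing import List


def padding_overhead(lengths: List[int]) -> float:
    """Padded work over useful work for per-token ops (e.g. the dense GEMMs).

    Padded work is B * T_max positions; useful work is sum(T_b) = B * T_avg,
    so the ratio equals T_max / T_avg. A value of 1.0 means no waste.
    """
    return len(lengths) * max(lengths) / sum(lengths)


# Hypothetical batch of 8 with T_max = 256 and T_avg = 64 (sum = 512).
lengths = [256, 40, 36, 36, 36, 36, 36, 36]
print(padding_overhead(lengths))  # 4.0
```

A single long sequence is enough to set T_max for the whole batch, which is why growing the batch tends to grow the waste.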

There is an alternative way to handle padding, which was introduced in cuDNN’s variable-sequence-length RNN implementations and re-popularized by ByteDance in effective_transformer. Intuitively, we just convert the input into a CSR-like representation (a dense (sum(T_b), ...) tensor instead of a padded (B, T_max, ...) tensor), run the majority of our computation on that packed representation, and convert back to the padded representation only where required (e.g. for the batch matmul in the attention computation).
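
The pack/unpack round trip can be sketched with lists standing in for tensors; in a real implementation each entry would be a feature vector and the gather/scatter would be a GPU kernel. The function names are hypothetical, and the offsets array plays the role of the CSR row pointer.

```python
from itertools import accumulate
from typing import List, Tuple


def unpad(padded: List[List[int]], lengths: List[int]) -> Tuple[List[int], List[int]]:
    """(B, T_max) padded rows -> packed (sum(T_b),) values plus CSR-style offsets.

    Sequence b occupies packed[offsets[b]:offsets[b + 1]].
    """
    packed = [tok for row, n in zip(padded, lengths) for tok in row[:n]]
    offsets = [0] + list(accumulate(lengths))
    return packed, offsets


def repad(packed: List[int], offsets: List[int], pad_value: int = 0) -> List[List[int]]:
    """Packed values + offsets -> (B, T_max) padded rows, e.g. before an attention bmm."""
    segments = [packed[s:e] for s, e in zip(offsets[:-1], offsets[1:])]
    t_max = max(len(seg) for seg in segments)
    return [seg + [pad_value] * (t_max - len(seg)) for seg in segments]


packed, offsets = unpad([[5, 6, 7], [8, 0, 0]], lengths=[3, 1])
# packed == [5, 6, 7, 8], offsets == [0, 3, 4]
# A per-token op (GEMM, layernorm, ...) now touches 4 entries instead of 6.
doubled = [2 * x for x in packed]
# repad(doubled, offsets) == [[10, 12, 14], [16, 0, 0]]
```

Because per-token ops never mix positions, they can run directly on the packed form; only ops that need the (B, T_max) structure pay for a repad.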

This approach resolves these two problems for us: we can form arbitrarily large batches (subject to latency constraints) and run them efficiently without wasting compute. We can also preserve a completely identical API, and just swap out the Transformer nn.Module for an alternative implementation that uses this approach under the hood.

This approach is implemented in FasterTransformer, effective_transformer, TurboTransformers, and I’m sure many others.


Integrating FasterTransformer into PyText ended up being a fairly trivial affair (which is nice!). We just implemented a shim that remapped the weights from the PyText representation to the FasterTransformer representation, implemented a module-swapping function, and fixed a numerical difference between our PyText implementation and the FasterTransformer implementation. All the code is in the PyText repo in

Standalone Benchmarks

Running the standalone Transformer implementations at various settings (batch size, average sequence length, maximum sequence length) lets us separate the impact of the two changes (kernel fusion and efficient padding handling). The blue PyText values correspond to the baseline; the purple EffectiveTransformer values refer to an implementation that uses the existing PyText GPU kernels but adds the EffectiveTransformer pad/unpad trick; and the red FasterTransformer values refer to our implementation using NVIDIA’s FasterTransformer.

In the non-padded case (B=64, T_avg=256, T_max=256), we get roughly a 1.5x speedup on V100 and roughly a 2x speedup on A100, which comes purely from kernel fusion and other optimizations in the FasterTransformer stack. In the padded case (B=64, T_avg=64, T_max=256), we see roughly a 4x improvement on V100 (roughly equal to the amount of “wasted” computation in the baseline, and close to our E2E system results) and roughly 6x on A100. We get fairly close to device capability for these models: roughly 75-80% of cuBLAS peak in the compute-bound case on both GPUs.

Next Steps

There is even more low-hanging fruit here. When we compute the batch matrix multiplications in attention, we still have to pad up to the maximum sequence length, which means we are still wasting some compute. Avoiding this requires a segmented batched matrix multiplication, which NVIDIA has already implemented for us (qkvToContext in TensorRT), and in preliminary benchmarks we see around 1.5x improvements from TensorRT in the padded scenarios (and ~1.1x in the compute-bound scenarios).
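
The idea behind a segmented batched matmul can be sketched in pure Python: compute each sequence’s Q @ K^T directly from the packed rows, using the same CSR-style offsets, so padded rows and columns are never touched. This is a naive stand-in written for illustration, not TensorRT’s qkvToContext kernel, and the function names are hypothetical. The second function compares the attention-bmm work of the padded layout (B * T_max^2) with the segmented one (sum of T_b^2).

```python
from typing import List

Matrix = List[List[float]]


def segmented_scores(q: Matrix, k: Matrix, offsets: List[int]) -> List[Matrix]:
    """Per-sequence Q @ K^T computed on packed (sum(T_b), d) rows.

    Sequence b uses rows offsets[b]:offsets[b + 1] of both q and k, so each
    score matrix is T_b x T_b and no padded positions are ever computed.
    """
    out = []
    for start, end in zip(offsets[:-1], offsets[1:]):
        qs, ks = q[start:end], k[start:end]
        out.append([[sum(a * b for a, b in zip(qi, kj)) for kj in ks] for qi in qs])
    return out


def attention_cost_ratio(lengths: List[int]) -> float:
    """Padded bmm work (B * T_max^2) over segmented bmm work (sum of T_b^2)."""
    t_max = max(lengths)
    return len(lengths) * t_max ** 2 / sum(t * t for t in lengths)


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [2.0, 3.0]]
scores = segmented_scores(q, k, offsets=[0, 2, 3])
# scores == [[[1.0, 0.0], [0.0, 1.0]], [[5.0]]]
```

Since sum(T_b^2) <= T_max * sum(T_b), the attention ratio is never smaller than the per-token ratio B * T_max / sum(T_b). For hypothetical lengths [256, 40, 36, 36, 36, 36, 36, 36] (T_max = 256, T_avg = 64) it is about 7.0 versus 4.0, though this covers only the attention bmms, not the whole layer.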

Of course, substantially increasing efficiency allows us to run larger and better models with the same capacity footprint, so that’s a direction the team will be exploring.

More and more people are switching to Transformer-based models across speech, NLP, NMT, vision, and so on. For efficiency, it’s important that we make the most of our hardware, and the existing GPU ecosystem makes that fairly straightforward. It’s also important not to overfit our platforms and software to specific architectures or modeling choices, which is why we should continue to improve baseline PyTorch performance and ideally automate these improvements (through JIT, NNC, nvFuser, etc.). But for building blocks as common as the standard BERT Transformer, it’s justifiable to spend a bit of effort to make them as fast as possible.