TL;DR
- Compiled Adam outperformed the SOTA hand-optimized APEX optimizers on all benchmarks, with speedups of 62.99% on TorchBench, 53.18% on HuggingFace, 142.75% on TIMM, and 88.13% on BlueBerries
- Compiled AdamW performed similarly, with up to a 2x improvement on TIMM
- Compiled SGD had at least a 30% speedup on all benchmarks.
- Some models are particularly sensitive to overheads; for these models, the cpp_wrapper or cudagraphs can improve performance by 2x or more.
Background
Currently, torch.compile supports compiling 11 of the 13 optimizers into optimized foreach kernels. To drive adoption, it is necessary to show that compiled optimizers can reach SOTA performance by outperforming APEX, the NVIDIA fused optimizer kernel library that is the main high-performance alternative to the native PyTorch optimizers. APEX has three optimizers in common with torch.compile: SGD, AdamW, and Adam. These optimizers are evaluated on models from the HuggingFace, TorchBench, and TIMM benchmark suites.
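For reference, the pattern used throughout this post is to wrap the optimizer step in a compiled function. Below is a minimal sketch of that pattern; the model, sizes, and hyperparameters are placeholders rather than the benchmark configuration.

```python
import torch

# Minimal sketch: compile the optimizer step of a toy model (requires CUDA).
model = torch.nn.Linear(1024, 1024, device="cuda")
opt = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=True)

@torch.compile
def opt_step():
    opt.step()

# One training iteration: eager forward/backward, compiled optimizer step.
loss = model(torch.randn(32, 1024, device="cuda")).sum()
loss.backward()
opt_step()
```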
Speedup Results (Raw Data) (script)
Analysis
The different test suites had varying performance characteristics due to differences in model size. Since the APEX and torch.compile-generated kernels are so fast, TorchBench models, along with the smaller TIMM models, were particularly sensitive to overheads. For example, the profile for the compiled Adam optimizer on dcgan is shown below.
Figure 1: dcgan torch.compile profile
This profile is almost entirely dominated by overhead; the Triton kernel launches are the tiny slivers in the lower right-hand corner. Further investigation showed the overhead was due to three components: 1) guards, 2) cudagraphs checks, and 3) cudagraph launches.
Guards are conditions that torch.compile must check on each invocation to verify that the compiled artifact is still valid. Three approaches were used to lower this overhead. First, to improve guard performance in general, Animesh Jain moved guard evaluation to C++. Second, since optimizer state tensors do not change across iterations, tensor shape guards were replaced with guards on their data pointers, which are much faster to evaluate. Finally, the repetitive structure of the optimizer produced redundant guards that could be removed entirely.
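To see which guards end up installed for a compiled optimizer step, the guard set can be dumped through PyTorch's logging controls; torch._logging.set_logs(guards=True) is the programmatic equivalent of running with TORCH_LOGS="guards". The sketch below uses toy parameters standing in for a real model.

```python
import torch

# Log the guards Dynamo installs for each compiled frame.
torch._logging.set_logs(guards=True)

# Toy parameters with pre-populated gradients, standing in for a real model.
params = [torch.randn(16, device="cuda", requires_grad=True) for _ in range(4)]
for p in params:
    p.grad = torch.randn_like(p)
opt = torch.optim.Adam(params, foreach=True)

@torch.compile
def opt_step():
    opt.step()

opt_step()  # the guard set for the compiled optimizer step is logged here
```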
For 2), cudagraphs checks that the data pointers of parameter tensors remain static across invocations in order to guarantee the correctness of the recorded cudagraph. These checks were initially implemented in Python; moving them to C++ improved performance significantly.
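For reference, cudagraphs for the compiled optimizer step are enabled via torch.compile's "reduce-overhead" mode. The sketch below reuses the placeholder setup from the Background sketch and is illustrative only.

```python
import torch

# "reduce-overhead" mode wraps the compiled region in a cudagraph, which is
# why the static data pointer checks described above are needed.
model = torch.nn.Linear(1024, 1024, device="cuda")
opt = torch.optim.Adam(model.parameters(), foreach=True)

@torch.compile(mode="reduce-overhead")
def opt_step():
    opt.step()

model(torch.randn(8, 1024, device="cuda")).sum().backward()
opt_step()
```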
This overhead initially made compiled Adam 20% slower than APEX on TorchBench even with cudagraphs, but after the above changes, compiled Adam outperformed APEX by up to 2x on TIMM. Overall, HuggingFace appeared to give the most realistic results, as those kernels operate on tensors large enough that kernel runtime begins to dominate total runtime, which is the more typical scenario for mainstream models. To illustrate this, the profile for XGLMForCausalLM from HuggingFace is shown below.
Figure 2: XGLMForCausalLM torch.compile profile
Note on SGD
SGD is an interesting test case for comparing overhead with eager mode. Because eager SGD is already a single memory-bound kernel, there are no vertical fusion opportunities, which is illustrated by the lack of speedup (~0.96x) on TorchBench and HuggingFace with generic torch.compile.
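To make that concrete, the sketch below is a rough illustration of what an eager foreach SGD update (no momentum, no weight decay) boils down to: a single horizontally fused foreach op over all parameters, leaving nothing for vertical fusion to combine. This is an illustration, not the actual torch.optim.SGD implementation.

```python
import torch

# Rough sketch of an eager foreach SGD update: p -= lr * g for every
# parameter, issued as one horizontally fused _foreach call.
params = [torch.randn(1024, 1024, device="cuda", requires_grad=True) for _ in range(4)]
grads = [torch.randn_like(p) for p in params]
lr = 0.01

with torch.no_grad():
    torch._foreach_add_(params, grads, alpha=-lr)
```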
Why is torch.compile faster?
Since the APEX optimizers are highly optimized, it is important to address why torch.compile is faster. There could be multiple contributing factors: Python overhead (which torch.compile traces away), better code generation from Triton, or other optimizations. Below are the profiles generated for XGLMForCausalLM with torch.compile and APEX, respectively.
Figure 3: Torch.Compile XGLMForCausalLM Profile
Figure 4: APEX XGLMForCausalLM Profile
From these profiles, the answer is a combination of the two: overhead is reduced by ~2ms, and the kernels themselves are about 4ms faster due to better performance of the Triton code.
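Traces like the ones above can be collected with the PyTorch profiler on any model; the sketch below uses the same placeholder setup as earlier and is not the benchmark harness itself.

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Placeholder model/optimizer; warm up the compiled step first so that
# compilation does not show up in the trace.
model = torch.nn.Linear(1024, 1024, device="cuda")
opt = torch.optim.Adam(model.parameters(), foreach=True)

@torch.compile
def opt_step():
    opt.step()

model(torch.randn(8, 1024, device="cuda")).sum().backward()
opt_step()  # warmup / compilation

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(5):
        opt_step()
    torch.cuda.synchronize()

# Open the resulting trace in chrome://tracing or Perfetto.
prof.export_chrome_trace("compiled_optimizer_trace.json")
```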
TIMM Performance
TIMM performance was surprisingly slow without the cpp wrapper or cudagraphs. For example, on ghostnet_100, adding the cpp wrapper improves performance by 2x. The profile is shown below.
Figure 5: ghostnet_100 torch.compile profile with launch overhead
Looking at this profile, there are obvious gaps between kernel launches. These gaps are due to argument pointer extraction and launch overhead in the inductor-generated Python code that launches the generated Triton kernels. This problem is unique to foreach kernels because, in order to maximize the available memory bandwidth, they take a much larger number of arguments than a typical kernel. In addition, the argument tensors themselves are very small, so the kernels run so fast that they finish before the next kernel can be launched, decreasing GPU utilization.
After enabling the cpp wrapper, inductor generates the launcher code in C++ instead of Python, which lowers this overhead significantly, resulting in the following profile.
Figure 6: ghostnet_100 torch.compile profile w/ cpp_wrapper
This drastically reduces the launch overhead per kernel, resulting in a > 2x speedup.
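For completeness, the cpp wrapper is an inductor configuration flag; the sketch below toggles it globally before compiling (set it before compiling something like the opt_step pattern from the Background sketch, which is assumed here).

```python
import torch._inductor.config as inductor_config

# Sketch: have inductor emit its kernel launcher code in C++ rather than
# Python. Must be set before torch.compile traces/compiles the function.
inductor_config.cpp_wrapper = True
```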
Overall, with these optimizations, the compiled optimizers outperform the APEX optimizers on HuggingFace, TIMM, and TorchBench.
Next Steps
- Performance can be improved further with autotuning tile sizes
- LRScheduler compatibility with torch.compile’d optimizers
Acknowledgements
Thanks to Animesh Jain for guard optimizations, Elias Ellison for help with optimizing cudagraphs, Jason Ansel for design feedback, and Jane Xu for reviewing this post.