We are excited to announce that over the first half of 2024 we have been prioritizing improvements to compile times for torch.compile workflows. Swift iterations and efficient development cycles are very important to us, and to that end we have been planning and executing ways to bring compile time down to zero.
Today, I’d like to give a preview of the work we have planned and are actively executing on. To get the best experience with any feature mentioned here, please use the most recent nightly PyTorch build. Please keep in mind that these strategies are still works in progress. As these features mature and are ready for production use, we will announce them broadly via the PyTorch release process.
Let’s break down torch.compile’s compile times. There are three main parts of torch.compile: Dynamo, AOT Autograd, and Inductor. In most of our training benchmarks, we usually observe:
- Dynamo: 5-10% of compilation
- AOT Autograd: 40-45% of compilation
- Inductor (code generation, autotuning, Triton kernel compilation): 50% of compilation
This means that if we could cache the compilation artifacts of AOT Autograd and Inductor, we could realistically reduce warm compilation times by 90+%. The remaining compilation time, which is dominated by TorchDynamo, can then be addressed with good old-fashioned performance work.
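To make the cold versus warm distinction concrete, here is a minimal sketch (the model, shapes, and timing approach are arbitrary illustrations, not a benchmark): within a single process, the first call pays the full cost of Dynamo, AOT Autograd, and Inductor, and subsequent calls on the same graph are essentially free. A warm start is when a fresh process can skip most of that first-call cost thanks to the caches described below.

```python
import time
import torch

# A minimal sketch: the first call through torch.compile pays the full (cold)
# compilation cost; later calls on the same FX graph do not recompile.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 8),
)
compiled = torch.compile(model)
x = torch.randn(32, 64)

start = time.perf_counter()
compiled(x)  # first call triggers Dynamo, AOT Autograd, and Inductor
print(f"cold start: {time.perf_counter() - start:.2f}s")

start = time.perf_counter()
compiled(x)  # same FX graph: no recompilation in this process
print(f"steady state: {time.perf_counter() - start:.4f}s")
```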
To this end, here is our roadmap.
Cold & Warm Start Improvements - Caching In Many Parts
First, let’s identify when to expect the caching behavior. As a rule of thumb, caching will kick in when the FX Graph provided by TorchDynamo is identical to the one from the previous iteration; this can be verified by using TORCH_LOGS=+dynamo. If the FX Graph is modified, then the amount of caching will vary based on the particular modification.
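For example, here is a small sketch of that verification step. The function and shapes are made up, and the exact point at which TORCH_LOGS is read is an assumption, so we set it before importing torch (running the script as `TORCH_LOGS=+dynamo python train.py` works as well):

```python
import os

# Set the logging knob before torch is imported so the debug logs are active.
os.environ["TORCH_LOGS"] = "+dynamo"

import torch

def fn(x):
    return torch.relu(x) + 1

compiled = torch.compile(fn)
compiled(torch.randn(8))   # debug logs show the FX graph Dynamo produced
compiled(torch.randn(8))   # identical graph: no new tracing is logged
compiled(torch.randn(16))  # a different input shape may trigger a recompile
```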
- Dynamo
Around the tail end of last year and the beginning of this year, we built improvements to FakeTensor caching within a single run, and we will continue to improve this strategy throughout this year. This cache can be controlled via TORCH_FAKE_TENSOR_DISPATCH_CACHE=1 and is already turned on by default.
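Since the cache is on by default, the only time you would touch this knob is when comparing with it off. A rough sketch follows; treat the exact parsing of the value ("0" to disable) as an assumption, and note that the variable needs to be set before torch is imported:

```python
import os

# Assumed: setting the flag to "0" before importing torch disables the
# FakeTensor dispatch cache for an A/B comparison of compile times.
os.environ["TORCH_FAKE_TENSOR_DISPATCH_CACHE"] = "0"

import torch

compiled = torch.compile(torch.nn.Linear(128, 128))
compiled(torch.randn(4, 128))
```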
Another idea we have been toying with is hierarchical compilation. Suppose there is some looping primitive; inside it, Dynamo currently performs unrolling, which results in the body of the loop being inlined and thus compiled multiple times. The same applies to nn.Modules. The goal of hierarchical compilation is to take advantage of the information from the previously compiled loop body.
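Here is an illustrative sketch of the unrolling behavior described above (the module itself is a made-up example): Dynamo inlines the loop body once per iteration, so each block is traced as part of one large graph rather than being compiled once and reused.

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))

class Stack(nn.Module):
    def __init__(self, dim, depth):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))

    def forward(self, x):
        for block in self.blocks:  # unrolled: the body is inlined `depth` times
            x = block(x)
        return x

compiled = torch.compile(Stack(dim=64, depth=8))
compiled(torch.randn(2, 64))
```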
- AOTAutograd
We’ve begun extending our work on FXGraphCache to a new caching layer in AOTAutograd. To do so, we construct a safe cache key from the set of inputs to autograd and global configurations. We’ve also refactored AOTAutograd into a set of clear pre- and post-compilation phases, which allows AOTAutograd to easily save and reconstruct the wrapped runtime callable it produces. This will allow us to avoid doing joint graph analyses, autograd, partitioning, and compilation in most warm start scenarios.
This work is still in progress and currently not yet ready to be used.
- Inductor - FX Graph Compilation
Given the same FX Graph and the same configuration, we should be able to reuse the existing compilation artifacts. This cache can be enabled via TORCHINDUCTOR_FX_GRAPH_CACHE=1.
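A minimal sketch of turning this on for a script (the model and shapes are arbitrary); the warm path is exercised by running the same script a second time in a fresh process:

```python
import os

# Enable Inductor's FX graph cache before torch is imported.
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"

import torch

model = torch.compile(torch.nn.Linear(256, 256))
model(torch.randn(8, 256))  # a second run of this script should hit the cache
```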
- Inductor - Autotuning Cache
Inductor emits Triton kernels that need to be autotuned in order to select the best BLOCK_SIZE; caching these autotuning results is also enabled by default.
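One way to see the autotuning cost (and the benefit of caching it) is to opt into a heavier autotuning mode explicitly. The sketch below is illustrative and assumes a CUDA device is available:

```python
import torch

# "max-autotune" asks Inductor to benchmark kernel configurations (e.g.
# BLOCK_SIZE choices) on the first run; cached autotuning results make
# subsequent runs of the same graph much faster.
model = torch.nn.Linear(1024, 1024).cuda()
compiled = torch.compile(model, mode="max-autotune")
compiled(torch.randn(1024, 1024, device="cuda"))
```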
- Triton Kernel Compilation
We have worked with the Triton team at OpenAI to implement caching to improve compilation times. This cache is also enabled by default.
Going From Local To Remote
Local caches are very effective for small models, but for larger models running on many machines, remote caches can provide greater benefits.
Historically, the caches implemented in Inductor have targeted on-disk or in-memory storage. Moving from local to remote caching means the cold-start cost is paid only once globally, while the rest of the machines always see warm starts. We suspect that, in the majority of cases, this will reduce overall compilation time, rather than only speeding up later runs of the same job.
Remote caching of course brings additional challenges, such as making sure the cached artifacts are portable between platforms and that the cache key correctly captures all hardware and source code information (in our case, the version of the PyTorch compiler source code, GPU capabilities, and the Triton source, since all of these can change without any versioning).
For all the caches mentioned above, we have also built remote versions that use a Redis backend. These remote versions are still being worked on; feel free to reach out to us on GitHub if you or your company would like to get a head start on testing them out.
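As a rough sketch of what pointing a job at a Redis backend might look like: the environment variable names and host below are assumptions about work-in-progress knobs, so please check the nightly source (or ask us on GitHub) before relying on them.

```python
import os

# Assumed, work-in-progress knobs for the Redis-backed remote caches;
# the exact names and values may differ as this feature matures.
os.environ["TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE"] = "1"    # assumed flag name
os.environ["TORCHINDUCTOR_REDIS_HOST"] = "cache.internal"  # hypothetical host
os.environ["TORCHINDUCTOR_REDIS_PORT"] = "6379"

import torch

compiled = torch.compile(torch.nn.Linear(512, 512))
compiled(torch.randn(16, 512))
```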
Disabling Caches
In order to measure cold start compilation time or to debug cache corruption, it is possible to set TORCHINDUCTOR_FORCE_DISABLE_CACHES=1, which will override any other caching configuration and disable all compile time caching.
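For example, here is a minimal sketch (model, shapes, and timing are arbitrary) of measuring cold-start compilation with all caches forced off:

```python
import os

# Force every compile time cache off so each run pays the full cold cost.
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"

import time
import torch

model = torch.compile(torch.nn.Linear(256, 256))
x = torch.randn(8, 256)

start = time.perf_counter()
model(x)  # always a cold compile, regardless of other cache settings
print(f"cold compile: {time.perf_counter() - start:.2f}s")
```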