TorchDynamo Update 6: Training support with AOTAutograd

Recap

We are working on an experimental project called TorchDynamo. TorchDynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. TorchDynamo hooks into the frame evaluation API in CPython (PEP 523) to dynamically modify Python bytecode right before it is executed. It rewrites Python bytecode in order to extract sequences of PyTorch operations into an FX graph, which is then just-in-time compiled with an ensemble of different backends and autotuning. It creates this FX graph through bytecode analysis, and it is designed to mix Python execution with compiled backends to get the best of both worlds: usability and performance.

If you are new here, the TorchDynamo README is a good place to start. You can also catch up on our prior posts.

Adding Training in TorchDynamo

The biggest change since last time has been adding training support with AOTAutograd.

Training adds challenges because the PyTorch Automatic Differentiation engine sits below the PyTorch dispatcher in C++. Therefore, the operators running in the backward pass are not directly visible to TorchDynamo at the Python level.

To support training with TorchDynamo, we need to run/optimize the operations that happen in the .backward() pass. Supporting the backward pass can be done in a few different ways:

  • Eagerly: backends could use the dynamic autograd tape on every call, the same as eager mode.
  • TorchScript: a hybrid approach. It runs some things eagerly using the dynamic tape, and for the rest it has a separate implementation of many autograd formulas. This is difficult to maintain and does not support all operations.
  • AOTAutograd: records the behavior of the eager, dispatcher-based autograd once at compile time. This allows it to capture everything with a simpler and more robust implementation that reuses much more of eager-mode autograd. In addition, it allows us to easily make optimization decisions with visibility into both the forward and backward graphs.
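
To make this concrete, below is a minimal sketch of a training step run under TorchDynamo with a toy custom backend (my_compiler is a placeholder name of our own, following the torchdynamo.optimize API in the README at the time of writing). Note that without AOTAutograd, the backend only ever sees forward graphs:

```python
import torch
import torchdynamo  # experimental package

def my_compiler(gm: torch.fx.GraphModule, example_inputs):
    # TorchDynamo hands the backend a forward-only FX graph;
    # returning gm.forward runs the captured graph without changes.
    print(f"captured a graph with {len(gm.graph.nodes)} nodes")
    return gm.forward

model = torch.nn.Linear(16, 16)
x = torch.randn(8, 16)

with torchdynamo.optimize(my_compiler):
    loss = model(x).mean()  # forward ops are visible to my_compiler
loss.backward()             # backward ops run in the C++ autograd engine
```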

AOTAutograd

AOTAutograd relies on the recently introduced __torch_dispatch__-based tracing mechanism to capture the backward graph ahead of time. This means it reuses the PyTorch core autograd engine to generate the backward graph, which is a benefit over the parallel implementation of symbolic autodiff in TorchScript, which is difficult to maintain over time.

Additionally, the AOTAutograd design allows easier joint forward-backward graph optimizations, such as activation checkpointing (please refer to this post for more details). Tracing results in two separate graphs: one for the forward pass and one for the backward pass. We then compile these graphs separately using backend compilers. In this post, AOTAutograd uses TorchScript with NNC/nvFuser to compile the generated forward and backward graphs.
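
As a small illustration, AOTAutograd is exposed through functorch.compile, where you can pass separate compilers for the two graphs. The sketch below (assuming the aot_function API at the time of writing, with print_and_run as our own toy compiler) just prints each captured graph and runs it unmodified:

```python
import torch
from functorch.compile import aot_function

def print_and_run(fx_module, example_inputs):
    # Called once per graph at compile time: once for the forward
    # graph and once for the backward graph.
    fx_module.graph.print_tabular()
    return fx_module

def fn(x, y):
    return (x * y).sin().sum()

aot_fn = aot_function(fn, fw_compiler=print_and_run, bw_compiler=print_and_run)

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
aot_fn(x, y).backward()  # the backward graph was already captured ahead of time
```

Both graphs are captured on the first call, before any backward() runs, which is what allows optimization decisions to be made with both graphs in view.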

Results

We tested this integration on TorchBench models on NVIDIA A100 GPUs. Our training measurement is forward() + loss calculation + backward(), where the loss calculation is just mean() as a placeholder. This is not full training because it does not include the optimizer. We also set the models to eval mode, which removes randomness from operations like dropout so that we can verify accuracy and maintain confidence in the correctness of our benchmarks. This does introduce some divergence from actual training: for example, batch norm becomes fusible as a pointwise operator, while dropout is no longer fusible.

We check numerical accuracy by comparing the computed gradients. We measure both latency and peak memory footprint of the training iteration. The table below shows the speedup and memory savings of different configurations, normalized to eager performance. For AOTAutograd, we use the min-cut recomputation algorithm discussed in this post.
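
Concretely, the accuracy check boils down to something like the following simplified sketch (grads_match is a hypothetical helper, not the actual TorchBench harness, and it compares only the input gradients):

```python
import torch

def grads_match(model, compiled_fn, x, rtol=1e-3, atol=1e-3):
    # Run one eager and one compiled iteration of the measured step
    # (forward + mean() loss + backward) and compare input gradients.
    x_eager = x.clone().detach().requires_grad_(True)
    x_comp = x.clone().detach().requires_grad_(True)

    model(x_eager).mean().backward()
    compiled_fn(x_comp).mean().backward()

    return torch.allclose(x_eager.grad, x_comp.grad, rtol=rtol, atol=atol)
```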

Some of the observations:

  • The first two columns directly use TorchScript to run the full model training iteration. Many models fail to run correctly; this is where TorchDynamo helps (later columns). TorchDynamo finds subgraphs that are more amenable to scripting, increasing coverage from 55% to 95% with the NNC backend.
  • Both TorchScript and AOTAutograd achieve good speedups on many models. For example, AOTAutograd with nvFuser speeds up timm_vision_transformer by 1.30x. On average, AOTAutograd speeds up the TorchBench models by 1.09x.
  • AOTAutograd demonstrates a good reduction in the memory used for saved activations via activation checkpointing (aka recomputation), with savings such as 1.55x on some models.
  • The TorchDynamo + TorchScript + nvFuser case (no AOTAutograd) appears to have many failures. However, only a couple of underlying issues cause most of these models to fail; these are known nvFuser issues that the nvFuser team is looking into.
  • For the most part, AOTAutograd + nvFuser has both the best performance and the best memory usage of all configurations tested. There are a couple of exceptions, primarily due to the issues mentioned below.

Outstanding Issues

There is still some work remaining to get accuracy passing on all the TorchBench models. In the table above, we see 4 models failing, and a few more models are skipped here.

The failures are spread across different components (TorchDynamo, AOTAutograd, TorchBench, TorchScript, and nvFuser); the running list of these issues is here. Because this integration exercises so many components, it has revealed bugs/issues across all of them. Some of them are:

  • AOTAutograd

    • AOTAutograd can result in incorrect behavior if mutation is present in the graph (see the small illustration after this list). Brian Hirsh is already working on a functionalization pass that can resolve this problem (github issue).
    • Contiguous tensors - AOTAutograd traces the backward pass assuming that the outputs of the forward pass and the corresponding incoming gradients for the backward pass have the same strides. However, there is no such guarantee. Currently, we have a suboptimal workaround: we force the outputs and the incoming gradients to be contiguous in AOTAutograd (issue).
    • Overhead - Our analysis shows suboptimal performance for some models, such as speech_transformer (which TorchDynamo splits into 55(!) different subgraphs). We suspect part of this is due to the additional overhead of AOTAutograd.
  • TorchScript

    • Scripting fails due to missing operator support (issue)
  • nvFuser

    • nvFuser recently started fusing view and reshape operators, which caused a temporary regression where many TorchBench models started failing. The nvFuser team has responded quickly to these issues and has fixed many of them in their latest code bump. For now, we have disabled the fusion of view operators.
    • There are some models, like soft_actor_critic and drq, where nvFuser seems to perform suboptimally. For example, on soft_actor_critic, AOT+NNC performs on par with eager while AOT+nvFuser performs worse. Note that this is not because nvFuser receives smaller fusion groups from AOTAutograd: in this case AOT+nvFuser produces 6 fusion groups containing 26 ops, while TS+nvFuser produces 5 fusion groups containing 15 ops. We will investigate further and discuss this with the nvFuser team.
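
To illustrate the mutation issue from the first bullet above, here is a minimal, hypothetical example of the kind of pattern that can trip up tracing:

```python
import torch

def f(x):
    x.mul_(2)       # in-place mutation of a graph input
    return x.sum()

x = torch.randn(4)
f(x)                # eager: x is doubled as a visible side effect
# A traced graph that drops or reorders the mul_ would silently
# diverge from eager behavior. Functionalization rewrites mul_ into
# an out-of-place mul and reapplies the mutation to the input, so
# the captured graph stays faithful to eager semantics.
```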

Next Steps

While there are still many outstanding issues, these results give us confidence that AOTAutograd + TorchDynamo can deliver speedups for training. Looking further ahead, there are bigger and more complex topics, such as supporting dynamic shapes and distributed training. These depend on ongoing efforts in PyTorch, but we are incredibly optimistic about recent progress there and look forward to an exciting future!

This is a joint collaboration with @jansel and @Chillee.
