TL;DR: We’ve implemented a min-cut based recomputation pass with AOTAutograd + NVFuser that consistently improves both memory and runtime across a wide range of models (including the TorchBench suite) for GPU training.
Intro
Recomputation (often called activation checkpointing) is a technique in which, instead of saving some activations for use in backwards, we recompute them during the backwards pass. Thus, we trade off longer runtime for less memory, right?
Actually, we can do better than that. In the presence of a fusing compiler, recomputation can actually reduce both memory and runtime. So, we propose a novel approach based on a min-cut solver that automatically performs recomputation to improve both memory and runtime.
Background
Pointwise operators are what’s known as “bandwidth-bound” operators: most of the time spent on them goes not to actually executing the computation, but to memory reads/writes. On GPUs, this usually means reading from and writing to global memory. Because of this, the cost of multiple fused pointwise operators is essentially just the memory cost. In other words, “load from memory” => “multiply by itself twice” => “write to memory” takes essentially the same time as “load from memory” => “multiply by itself once” => “write to memory”.
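Here’s a rough way to see this for yourself. This is only a sketch that assumes a CUDA device and an enabled TorchScript fuser such as NVFuser; exact numbers will vary, but once the multiplies are fused, both functions should take roughly the same time.

```python
import torch

@torch.jit.script
def one_mul(x):
    return x * x

@torch.jit.script
def three_muls(x):
    return x * x * x * x  # still one read and one write once fused

def bench(fn, x, iters: int = 100):
    for _ in range(10):  # warm up so the fuser has compiled its kernel
        fn(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

x = torch.randn(2**26, device="cuda")
print(bench(one_mul, x), bench(three_muls, x))  # expect similar numbers
```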
So, if you have a sequence of pointwise operators in training, then both the forwards pass and the backwards pass consist entirely of pointwise operators, and your runtime is essentially proportional to the amount of memory you’re reading and writing. As such, the typical result of autograd looks something like this:
Instead, we can optimize this by saving only the input to the forward pass, and recomputing the rest in backwards. Now, it looks like this:
So, not only do we cut the amount of memory saved in half, we’re also performing fewer memory accesses! This allows us to reduce both the runtime and the memory usage.
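For reference, this is the same idea you can express manually in eager mode with torch.utils.checkpoint, which saves only the inputs and recomputes the intermediates during the backwards pass. The min-cut pass described below makes this kind of decision automatically, and at a finer granularity; this is just a sketch of the manual analogue.

```python
import torch
from torch.utils.checkpoint import checkpoint

def pointwise_chain(x):
    # A chain of cheap, fusible pointwise ops.
    return x.sigmoid().tanh().relu()

x = torch.randn(2**20, device="cuda", requires_grad=True)
out = checkpoint(pointwise_chain, x)  # saves only x, recomputes the chain in backward
out.sum().backward()
```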
Decisions … Decisions…
Unfortunately, not all situations are as simple as the one above. For example, take this function:
def f(a, b, c, d):
    x = a + b + c + d
    return x.cos().cos()
That results in a joint graph like the one below.
Here, neither the standard autograd approach (saving all of the inputs to the cos calls) nor the recompute-everything approach described above is optimal. The standard autograd approach saves 2 tensors, thus performing 4 activation reads/writes, while the “completely recompute” strategy ends up saving the 4 inputs, which results in 4 activation reads (you don’t need to write them out from the forwards pass, but you still need to read them during the backwards pass).
Instead, we should save add_2, which is sufficient to compute the backwards pass, and results in only 1 activation read/write. In practice, this results in about a 25% performance improvement.
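To make that concrete, here’s a hand-written sketch of what saving only add_2 means for this function. This is purely illustrative; the pass derives this automatically from the joint graph.

```python
import torch

class F(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b, c, d):
        add_2 = a + b + c + d
        ctx.save_for_backward(add_2)  # one saved tensor instead of two
        return add_2.cos().cos()

    @staticmethod
    def backward(ctx, grad_out):
        (add_2,) = ctx.saved_tensors
        # d/dx cos(cos(x)) = sin(cos(x)) * sin(x), recomputed on the fly.
        grad = grad_out * torch.sin(add_2.cos()) * torch.sin(add_2)
        return grad, grad, grad, grad
```

Both sines in the backwards pass are cheap, fusible pointwise ops, so recomputing them costs essentially nothing beyond the single read of add_2.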
Here’s another example. This time, there are nodes that can’t be recomputed: perhaps they’re too expensive to recompute (matmuls), or they involve randomness. Let’s take this function (which is very similar to dropout):
def f(x):
    mask = torch.rand_like(x) < 0.5
    return x * x * mask
Once again, we probably want to perform some amount of recomputation. But… we need to make sure that we don’t recompute rand_like. One possibility is to simply save rand_like. In that case, however, we miss an optimization opportunity: instead of saving rand_like (a FloatTensor), we could save mask (a BoolTensor), and save 3 bytes per element. In this function, performing this optimization saves 25% of memory and 25% of runtime.
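As a quick sanity check on the dtype sizes involved (a trivial standalone snippet, not part of the pass):

```python
import torch

x = torch.randn(4, device="cuda")
r = torch.rand_like(x)  # float32: 4 bytes per element
mask = r < 0.5          # bool:    1 byte per element
print(r.element_size(), mask.element_size())  # prints: 4 1
```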
So, how do we account for all of these considerations automatically? Sounds like we need some algorithm…
Max-Flow/Min-Cut Solution
So, first let’s imagine the case where the only thing we have is pointwise ops. In other words, our entire forwards and backwards graphs are completely fusible. Let’s take the above example again, which is the joint forwards + backwards graph.
The only ops that must be computed in the backwards pass are those that depend (directly or transitively) on the tangents (i.e. the inputs to the backwards pass). This set of ops can also be called the tangent’s closure. In this case, that’s {mul, mul_1}. Everything else can either be recomputed in the backwards pass or saved from the forwards pass.
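As a rough sketch, the tangent’s closure can be computed with a single forward walk over the joint graph. Here joint_graph, its iteration order, and the node attributes are stand-ins for illustration, not the actual FX graph API.

```python
def tangents_closure(joint_graph, tangents):
    # Everything that (transitively) depends on the backwards-pass inputs must
    # be computed in the backwards pass, so it can never be saved from forwards.
    required_bw = set(tangents)
    for node in joint_graph.nodes:  # assumed to be in topological order
        if any(inp in required_bw for inp in node.inputs):
            required_bw.add(node)
    return required_bw
```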
So, since all of our ops are pointwise and thus fusible, the only thing that matters for performance is what we’re saving; more specifically, the total size of the tensors we’re saving.
So, to restate the problem: given the joint forwards and backwards graph, what is the minimal set of nodes to “choose” (i.e. save) such that the backwards pass can still be computed from those nodes plus the tangents (i.e. the backwards pass inputs)?
For example, {neg, neg_1} is one possible set of nodes (since you can compute mul and mul_1 from those nodes). Another option is {primals_1, primals_2, primals_3, primals_4}. But in this case, the best solution is {add_2}.
Now, if you’ve worked with flow problems before, it might be clear to you how this can be formulated as a max-flow/min-cut problem (which can be solved very efficiently).
Essentially, we are trying to find the partition between the source (i.e. input nodes) and the sink (i.e. nodes that must be in the backwards pass) such that the cut (i.e. cost to write/read activations) is minimized. In more concrete terms, add an edge between the source and all input nodes, as well as between all of the nodes in the tangent’s closure and the sink. Then, run a standard max-flow algorithm on this graph, and we have our solution. Ta-da!
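Here is a minimal sketch of that construction using networkx. This is an illustration, not the actual AOTAutograd pass: joint_graph, inputs, required_bw_nodes (the tangent’s closure), dont_recompute (ops we ban from recomputation), and weight are all assumed to be provided, and the node-splitting and “banned op” tricks are explained in the details below.

```python
import networkx as nx

def min_cut_saved_values(joint_graph, inputs, required_bw_nodes,
                         dont_recompute, weight):
    g = nx.DiGraph()
    for node in joint_graph.nodes:
        if node in required_bw_nodes:
            # These depend on the tangents, so they can never be saved from
            # the forwards pass: connect them straight to the sink.
            g.add_edge(node + "_in", "sink", capacity=float("inf"))
            continue
        # Standard "flow on nodes" trick: split each node into an in-half and
        # an out-half joined by an edge whose capacity is the node's weight.
        g.add_edge(node + "_in", node + "_out", capacity=weight(node))
        if node in inputs or node in dont_recompute:
            # Inputs (and anything we refuse to recompute) act as sources.
            g.add_edge("source", node + "_in", capacity=float("inf"))
    for src, dst in joint_graph.edges:
        # Data-dependency edges can never be cut; only nodes can.
        g.add_edge(src + "_out", dst + "_in", capacity=float("inf"))
    _, (reachable, non_reachable) = nx.minimum_cut(g, "source", "sink")
    # The saved activations are exactly the nodes whose split edge was cut.
    return [n for n in joint_graph.nodes
            if n + "_in" in reachable and n + "_out" in non_reachable]
```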
Some additional details:
- I thought flow was on edges, not nodes?: There’s a standard transformation from “flow on edges” to “flow on nodes”. Basically, you make all of the existing edges have infinite capacity. Then, you split each existing node into 2 nodes: one that takes all of the incoming edges and another that takes all of the outgoing edges. Finally, you add an edge between these 2 new nodes with capacity equal to the weight of the node.
- What do you do if we don’t want to recompute some of our nodes? Perhaps they’re compute-intensive, or just aren’t fusible: This case is actually pretty easy to solve. Just add another edge from the source to the node you don’t want to recompute! From the perspective of the flow algorithm, that node is then just another “input”. So, if the backwards pass needs that value, the algorithm can either cut (i.e. save) it, incurring only the bandwidth cost of that node, or save some other value downstream of it that is sufficient. You can also imagine having multiple modes for this algorithm: one that tries to optimize runtime (i.e. only recomputes ops it can fuse), and another that tries to optimize memory (i.e. recomputes all ops except expensive ones).
- What do you actually set as the weight for a node? There are basically 2 cases: 1. If the node would otherwise not be written out to global memory in the forwards pass, then saving it costs both a write (in forwards) and a read (in backwards). 2. Otherwise, if the node already exists in global memory or must be written out anyways (such as an input to the forwards pass), then the only extra cost is reading it in the backwards pass. So, in case 1 the weight is 2 * num_bytes, and in case 2 it’s num_bytes. (A small sketch of this weighting appears after this list.)
- Why the asterisk on optimal? Well, what this pass guarantees is that the memory transfer between the forwards and backwards pass is optimally minimized, subject to constraints on what is allowed to be recomputed in the backwards pass. We currently make two simplifying assumptions: 1. minimizing the memory transferred between the passes (while not recomputing any unfusible ops) is also optimal for runtime, and 2. we ignore memory usage within each graph. Personally, I think these assumptions are fairly reasonable (and generally applicable), and my experiments so far seem to bear that out.
There are, however, several somewhat pathological cases. For one, we need to avoid recomputing too much of the graph. For example, if we recomputed the entire forwards pass in the backwards pass, then the peak memory usage of our backwards pass hasn’t changed, even though we’re passing very little memory between the forwards and backwards pass. In addition, although fusing ops together is mostly free, if we were to fuse, say, a hundred multiply ops (or 10 trigonometric ops), we could potentially find ourselves compute bound again.
So far, these haven’t really been significant issues, as there are usually hard constraints (i.e. don’t recompute matmuls) that prevent us from recomputing inordinately massive graphs. However, I plan to add some kind of heuristic that upweights nodes that imply more recomputation (i.e. are further away from the backwards graph).
- What other approaches have people taken? AFAIK, nobody else has examined recomputation in combination with fusion. If you don’t perform fusion, then, as far as I know, recomputation always runs in strictly more time. In addition, all of the existing approaches are either heuristic-based (like Echo: [1805.08899] Echo: Compiler-based GPU Memory Footprint Reduction for LSTM RNN Training) or require an ILP solver that runs for hours… (Checkmate: [1910.02653] Checkmate: Breaking the Memory Wall with Optimal Tensor Rematerialization). In contrast, my current flow-based approach runs in <0.4 seconds, even on a fairly large ViT model with ~5000 nodes. I also note that I’m using a max-flow implementation written in pure Python (networkx), and that this runtime could be reduced even more if needed.
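To make the weighting rule from the list above concrete, here is a hypothetical helper; numel, dtype_size, and already_in_global_memory are assumed attributes for illustration, not the real node API.

```python
def node_weight(node):
    # Bandwidth cost of saving `node` for the backwards pass.
    num_bytes = node.numel * node.dtype_size
    if node.already_in_global_memory:
        return num_bytes       # case 2: only the extra backwards read
    return 2 * num_bytes       # case 1: a forwards write plus a backwards read
```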
TorchBench Performance Results
First, let’s look at results on TorchBench GPU training. Note that GPU training has always been PyTorch’s primary focus, so this is probably the most difficult setting to improve on TorchBench. Performance is reported as the average of 20 runs, and memory is measured as the memory usage at the end of the forwards pass, before the backwards pass starts.
To showcase the flexibility of this approach, we run 2 different configurations of our min-cut recomputation solver. In the first one, called “conservative”, we only allow operators that can be easily fused by NVFuser to be recomputed, such as pointwise operators and reductions. This ensures that we only recompute operators that are easy to fuse, and thus are “free”. In the second one, called “aggressive”, we instead allow all operators to be recomputed, except for compute-intensive operations and operations involving randomness.
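As a rough illustration (the op names and helpers below are examples, not the actual configuration knobs), the two configurations boil down to different “is this op allowed to be recomputed?” predicates; anything they reject would be banned from recomputation by the min-cut solver.

```python
FUSIBLE_OPS = {"add", "mul", "cos", "sigmoid", "sum", "mean"}  # pointwise + reductions
COMPUTE_INTENSIVE_OPS = {"mm", "bmm", "convolution"}
RANDOM_OPS = {"rand_like", "randn_like", "bernoulli"}

def can_recompute_conservative(op):
    # Only recompute ops that NVFuser can easily fuse.
    return op in FUSIBLE_OPS

def can_recompute_aggressive(op):
    # Recompute everything except compute-intensive and random ops.
    return op not in COMPUTE_INTENSIVE_OPS and op not in RANDOM_OPS
```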
As we can see from the results, AOT + MinCut-Conservative performs well across the board in both memory and runtime. It almost never underperforms eager-mode by more than an insignificant margin, and overall nearly triples the average performance improvement compared to Torchscript while also providing 3.5% average memory savings.
On the other hand, AOT + MinCut-Aggressive consistently slashes memory cost (by 30% on average!) while not incurring a massive runtime cost (only 1.3% on average). Looking closer, though, many of these models do see a 5-10% reduction in performance. In certain cases, such as when a model is memory-bound or when doing distributed training, users will happily make this trade. In addition, memory savings can often be leveraged into better utilization, since they allow larger batch sizes to be used.
Limitations on TorchBench
There are many models in TorchBench that I don’t have results for. For the most part, they either 1. didn’t have a training component, or 2. I couldn’t figure out how to run them easily (like the RL models…). However, there are a couple of models that AOTAutograd fails on. The two in particular I ran into were fastNLP_Bert and speech_transformer, which both convert tensors into scalar values at some point. I couldn’t get some of the other models like this to run (like maskrcnn), but no doubt there are more. There were also a couple of models where the TS lowering failed, usually due to empty lists being parsed as Tensor lists and throwing a type error.
EvoNorm Results
I’ll highlight a couple more performance results.
The first one, the EvoNorm module, highlights the performance improvement that comes from a smarter recomputation strategy. EvoNorm (pytorch-image-models/evo_norm.py at bits_and_tpu · rwightman/pytorch-image-models · GitHub) is a newish normalization layer that Ross Wightman has been looking at for some time. Previously, Ross had struggled to speed this up with PyTorch on GPUs:
With NVFuser’s improvements, as well as our min-cut pass, though, we can now speed it up significantly. We try it in 2 different size settings.
| Input size | eager | AOT (recompute everything) | AOT (flow-based) | TorchScript |
|---|---|---|---|---|
| (128, 32, 128, 128) | 6617.91 | 2145.93 | 1474 | 3630.64 |
| (128, 2048, 8, 8) | 1932.90 | 874.97 | 545 | 1140.01 |
Note that Torchscript with NVFuser already speeds it up significantly (by nearly 2x!). However, even applying the trivial recomputation strategy of recomputing the entire forwards pass gets us a further 30-50% performance improvement. Going beyond that and applying our min-cut pass speeds it up even more, and we end up with something nearly 4x faster than eager mode.
If you’re curious what “magic” the cover photo was referring to, now you know.
nn.TransformerEncoder Results
We’ve been looking at speeding up PyTorch’s nn.TransformerEncoder (along with Natalia) - specifically, the pointwise operators. Previously, we were looking at each operator individually. With this pass, though, we can simply apply the fusion to the entire module.
We can improve both the memory and the runtime by ~11% compared to eager, and the memory by 11% and runtime by 3% compared to Torchscript.
Also, if the users desire a more memory-efficient version, the MinCut-Aggressive pass results in 45% memory savings, while still being about as efficient as Torchscript.
Conclusion
AOTAutograd is still a new API, and we’ve only scratched the surface of the optimizations it could enable. But… the results so far demonstrate that we can significantly improve both memory usage and runtime across a wide range of models, in the area that has been most resistant to compilers (and yet is the most important one): GPU training in PyTorch.
Moreover, the ease with which we implemented this optimization pass makes us very optimistic about the future of AOTAutograd as an extensibility point. The entirety of this (arguably) advanced optimization pass consists of 80 lines of Python (functorch/aot_autograd.py at main · pytorch/functorch · GitHub), and it requires no knowledge of PyTorch internals beyond the extensibility point we provide. Furthermore, we’ve talked to several MLSys researchers who are very interested in this extension point and have expressed excitement to use it.
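To give a rough sense of what using this extensibility point looks like, here is a usage sketch (assuming the functorch.compile API around the time of writing; the exact knobs for the conservative/aggressive recompute lists aren’t shown):

```python
import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition, ts_compile

def f(a, b, c, d):
    x = a + b + c + d
    return x.cos().cos()

# The min-cut partitioner decides which values of the joint graph to save, and
# the forwards/backwards graphs are each compiled with TorchScript (+ NVFuser).
compiled_f = aot_function(
    f,
    fw_compiler=ts_compile,
    bw_compiler=ts_compile,
    partition_fn=min_cut_rematerialization_partition,
)

inps = [torch.randn(2**20, device="cuda", requires_grad=True) for _ in range(4)]
out = compiled_f(*inps)
out.sum().backward()
```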
Finally, I’ll note that there are still plenty of easy performance wins left on the board here that are close to being finished:
1. We’re not fusing any in-place operators, due to an issue with Torchscript’s alias analysis pass and the way AOTAutograd represents its graph.
2. In many cases, I needed to disable NVFuser’s batch norm fusion due to NVFuser errors; enabling it is likely to expose more fusion opportunities.
3. NVFuser currently does not fuse views, which prohibits a lot of fusion opportunities in PyTorch code.
These are just the fixes/optimizations already in flight. There are plenty more optimization opportunities that we’re looking at, ranging from integrating CUDA Graphs to TASO-style graph optimizations to automatic offloading.