Min-cut optimal(*) recomputation (i.e. activation checkpointing) with AOTAutograd

TL;DR: We’ve implemented a min-cut based recomputation pass with AOTAutograd + NVFuser that consistently improves both memory and runtime across a wide range of models (including the TorchBench suite) for GPU training.


Recomputation (often called activation checkpointing) is a technique in which, instead of saving some activations for use in the backwards pass, we recompute them during the backwards pass. Thus, we trade off longer runtime for less memory, right?

Actually, we can do better than that. In the presence of a fusing compiler, recomputation can reduce both memory and runtime. So, we propose a novel approach that uses a min-cut solver to automatically decide what to recompute, improving both memory and runtime.

Background

Pointwise operators are what’s known as “bandwidth-bound” operators. This means that most of the time spent on such an operator goes not to executing the operation itself, but to memory reads/writes. On GPUs, these are usually reads/writes to global memory. Because of this, the cost of several fused pointwise operators consists almost entirely of memory costs. In other words, “load from memory” => “multiply by itself twice” => “write to memory” takes essentially the same time as “load from memory” => “multiply by itself once” => “write to memory”.

So, if you have a sequence of pointwise operators in training, then both the forwards pass and the backwards pass consist entirely of pointwise operators, and your runtime is essentially proportional to the amount of memory you’re reading and writing. As such, the typical result of autograd looks something like this:

Instead, we can optimize this by saving only the input to the forward pass, and recomputing the rest in backwards. Now, it looks like this:

So, not only do we halve the amount of memory saved, we’re also performing fewer memory accesses! This allows us to reduce both the runtime and the memory usage.
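A back-of-the-envelope way to see this is to count global-memory transfers per element. The sketch below assumes the chain in the figure is three sigmoids (as in the pseudocode later in this thread), that each pass is a single fused kernel touching global memory only for its inputs and outputs, and that sigmoid’s backward reuses its own output; `traffic` and the tensor names are illustrative only.

```python
# Count global-memory transfers (one per tensor read or written, per element)
# for out = sigmoid(sigmoid(sigmoid(x))), assuming each pass is one fused kernel
# whose intermediates otherwise stay in registers/shared memory.

def traffic(kernels):
    """Total number of tensor reads plus writes across a list of fused kernels."""
    return sum(len(k["reads"]) + len(k["writes"]) for k in kernels)


# Standard autograd: sigmoid's backward reuses its output, so the forward must
# also write the two intermediate sigmoid outputs (s1, s2) for the backward pass.
save_everything = [
    {"reads": ["x"], "writes": ["out", "s1", "s2"]},                   # fused forward
    {"reads": ["grad_out", "out", "s2", "s1"], "writes": ["grad_x"]},  # fused backward
]

# Recompute: save only x (already in memory) and redo the sigmoids in backward.
recompute = [
    {"reads": ["x"], "writes": ["out"]},
    {"reads": ["grad_out", "x"], "writes": ["grad_x"]},
]

print(traffic(save_everything), traffic(recompute))  # 9 5
```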

Decisions … Decisions…

Unfortunately, not all situations are as simple as the one above. For example, take this function:

def f(a, b, c, d):
    x = a + b + c + d
    return x.cos().cos()

That results in a joint graph like the one below.

Here, neither the standard autograd approach (save the inputs to both cos calls) nor the recompute-everything approach from the previous section is optimal. The standard autograd approach saves 2 intermediate tensors, so it performs 4 activation reads/writes, while the “completely recompute” strategy would save the 4 inputs, which results in 4 activation reads (you don’t need to write them out in the forwards pass, but you still need to read them during the backwards pass).

Instead, we should save add_2, which is sufficient to compute the backwards pass and costs only one activation write plus one activation read. In practice, this results in about a 25% performance improvement.
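The same accounting can be written as arithmetic, assuming all tensors have the same size and charging one unit per activation write or read. The node names follow the joint graph (add_2 is a + b + c + d, cos is the output of the first cos call); `save_cost` is just an illustrative helper.

```python
# Unit cost per saved tensor: one read in the backwards pass, plus one write in the
# forwards pass if the tensor would not otherwise be written to global memory.
INPUTS = {"primals_1", "primals_2", "primals_3", "primals_4"}


def save_cost(saved):
    return sum(1 if name in INPUTS else 2 for name in saved)


strategies = {
    "standard autograd": ["add_2", "cos"],  # the inputs to the two cos calls
    "recompute everything": sorted(INPUTS),
    "save add_2": ["add_2"],
}
for label, saved in strategies.items():
    print(f"{label}: {save_cost(saved)}")
# standard autograd: 4
# recompute everything: 4
# save add_2: 2
```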

Here’s another complication: there are a lot of nodes that can’t be recomputed, for example because they’re too expensive to recompute (matmuls) or because they involve randomness. Let’s take this function (which is very similar to dropout):

import torch

def f(x):
  mask = torch.rand_like(x) < 0.5
  return x * x * mask

Once again, we probably want to perform some amount of recomputation. But… we need to make sure that we don’t recompute rand_like. One possibility is to simply save rand_like. In this case, however, we miss an optimization opportunity: instead of saving rand_like (a FloatTensor, 4 bytes per element), we could save mask (a BoolTensor, 1 byte per element), and save 3 bytes per element. In this function, performing this optimization saves 25% of memory and 25% of runtime.

So, how do we account for all these considerations automatically? Sounds like we need an algorithm :slight_smile:

Max-Flow/Min-Cut Solution

So, first let’s imagine the case where all we have is pointwise ops. In other words, our entire forwards and backwards graphs are completely fusible. Let’s look at the joint forwards + backwards graph from the example above again.

The only ops that must be computed in the backwards pass are those that (directly or transitively) depend on the tangents (i.e. the inputs to the backwards pass). This set of ops can also be called the tangent’s closure. In this case, that’s {mul, mul_1}. Everything else can either be recomputed in the backwards pass or saved from the forwards pass.
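A sketch of computing the tangent’s closure: a breadth-first walk over each node’s consumers, starting from the tangents. The joint graph below is a hand-written reconstruction of the figure (forwards: add, add_1, add_2, cos, cos_1; backwards: sin, neg, mul, sin_1, neg_1, mul_1), and `USERS` and `tangent_closure` are hypothetical names, not functorch APIs.

```python
from collections import deque

# Reconstructed joint graph for f(a, b, c, d): each node maps to its consumers.
USERS = {
    "primals_1": ["add"], "primals_2": ["add"], "primals_3": ["add_1"],
    "primals_4": ["add_2"], "add": ["add_1"], "add_1": ["add_2"],
    "add_2": ["cos", "sin_1"], "cos": ["cos_1", "sin"], "cos_1": [],
    "sin": ["neg"], "neg": ["mul"], "sin_1": ["neg_1"], "neg_1": ["mul_1"],
    "tangents_1": ["mul"], "mul": ["mul_1"], "mul_1": [],
}


def tangent_closure(users, tangents):
    """Every node that (transitively) consumes one of the tangent inputs."""
    seen, queue = set(), deque(tangents)
    while queue:
        for user in users[queue.popleft()]:
            if user not in seen:
                seen.add(user)
                queue.append(user)
    return seen


print(sorted(tangent_closure(USERS, ["tangents_1"])))  # ['mul', 'mul_1']
```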

Since all of our ops are pointwise and thus fusible, the only thing that matters for performance is what we’re saving. More precisely, what matters is the total size of the tensors we save.

So, to restate the problem: given the joint forwards and backwards graph, what is the minimum-cost set of nodes to “choose” (i.e. save) such that the backwards pass can still be computed from those nodes plus the tangents (i.e. the backwards pass inputs)?

For example, {neg, neg_1} is one possible set of nodes (since you can compute mul and mul_1 from those nodes). Another option is {primals_1, primals_2, primals_3, primals_4}. But in this case, the best solution is {add_2}.
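Whether a candidate set is sufficient can be checked mechanically: start from the saved nodes plus the tangents, repeatedly mark any node whose inputs are all available, and test whether the tangent’s closure gets covered. This is an illustrative sketch on the same hand-reconstructed graph; `PRODUCERS` and `is_sufficient` are made-up names.

```python
# Reconstructed joint graph for f(a, b, c, d) (node -> consumers) and its closure.
USERS = {
    "primals_1": ["add"], "primals_2": ["add"], "primals_3": ["add_1"],
    "primals_4": ["add_2"], "add": ["add_1"], "add_1": ["add_2"],
    "add_2": ["cos", "sin_1"], "cos": ["cos_1", "sin"], "cos_1": [],
    "sin": ["neg"], "neg": ["mul"], "sin_1": ["neg_1"], "neg_1": ["mul_1"],
    "tangents_1": ["mul"], "mul": ["mul_1"], "mul_1": [],
}
TANGENTS = {"tangents_1"}
REQUIRED = {"mul", "mul_1"}

# Invert the graph: node -> the nodes it reads from.
PRODUCERS = {node: [] for node in USERS}
for node, consumers in USERS.items():
    for consumer in consumers:
        PRODUCERS[consumer].append(node)


def is_sufficient(saved):
    """Can the backwards pass be computed from `saved` plus the tangents?"""
    available = set(saved) | TANGENTS
    changed = True
    while changed:  # keep recomputing until nothing new becomes available
        changed = False
        for node, inputs in PRODUCERS.items():
            # Graph inputs (no producers) can't be recomputed; they must be saved.
            if node not in available and inputs and all(i in available for i in inputs):
                available.add(node)
                changed = True
    return REQUIRED <= available


for candidate in (["neg", "neg_1"],
                  ["primals_1", "primals_2", "primals_3", "primals_4"],
                  ["add_2"],
                  ["cos"]):
    print(candidate, is_sufficient(candidate))
# ['neg', 'neg_1'] True
# ['primals_1', 'primals_2', 'primals_3', 'primals_4'] True
# ['add_2'] True
# ['cos'] False
```

The last candidate fails because neg_1 still needs add_2, which cannot be rebuilt from cos alone.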

Now, if you’ve worked with flow problems before, it might be clear to you how this can be formulated as a max-flow/min-cut problem (which can be solved very efficiently) :slight_smile:

Essentially, we are trying to find the partition between the source (i.e. input nodes) and the sink (i.e. nodes that must be in the backwards pass) such that the cut (i.e. cost to write/read activations) is minimized. In more concrete terms, add an edge between the source and all input nodes, as well as between all of the nodes in the tangent’s closure and the sink. Then, run a standard max-flow algorithm on this graph, and we have our solution. Ta-da!
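The whole formulation fits in a short, self-contained sketch. It uses the node-splitting trick described below (each node becomes an `_in`/`_out` pair joined by an edge whose capacity is the node’s weight) and a plain Edmonds-Karp max-flow; it is an illustration under those assumptions, not the code in functorch. With unit-size tensors, primal inputs weigh 1 and intermediates weigh 2.

```python
from collections import deque

INF = float("inf")

# Reconstructed joint graph for f(a, b, c, d) (node -> consumers).
USERS = {
    "primals_1": ["add"], "primals_2": ["add"], "primals_3": ["add_1"],
    "primals_4": ["add_2"], "add": ["add_1"], "add_1": ["add_2"],
    "add_2": ["cos", "sin_1"], "cos": ["cos_1", "sin"], "cos_1": [],
    "sin": ["neg"], "neg": ["mul"], "sin_1": ["neg_1"], "neg_1": ["mul_1"],
    "tangents_1": ["mul"], "mul": ["mul_1"], "mul_1": [],
}
PRIMALS = {"primals_1", "primals_2", "primals_3", "primals_4"}
REQUIRED = {"mul", "mul_1"}  # the tangent's closure


def weight(node):
    # Unit-size tensors: inputs cost one read; intermediates a write plus a read.
    return 1 if node in PRIMALS else 2


def min_cut(users, required, sources, weight_fn):
    """Return (cut cost, nodes to save) via node splitting + Edmonds-Karp."""
    cap = {}

    def add_edge(u, v, c):
        cap.setdefault(u, {})
        cap.setdefault(v, {})
        cap[u][v] = cap[u].get(v, 0) + c
        cap[v].setdefault(u, 0)  # residual back-edge

    for node, consumers in users.items():
        if node.startswith("tangents"):
            continue  # backwards-pass inputs are never saved
        if node in required:
            add_edge(node + "_in", "sink", INF)
            continue
        if node in sources:
            add_edge("source", node + "_in", INF)
        add_edge(node + "_in", node + "_out", weight_fn(node))  # cost of saving node
        for user in consumers:
            add_edge(node + "_out", user + "_in", INF)

    flow = 0
    while True:
        # BFS for a shortest augmenting path in the residual graph.
        parent, queue = {"source": None}, deque(["source"])
        while queue and "sink" not in parent:
            u = queue.popleft()
            for v, c in cap[u].items():
                if c > 0 and v not in parent:
                    parent[v] = u
                    queue.append(v)
        if "sink" not in parent:
            break
        path, v = [], "sink"
        while parent[v] is not None:
            path.append((parent[v], v))
            v = parent[v]
        push = min(cap[u][v] for u, v in path)  # bottleneck capacity
        for u, v in path:
            cap[u][v] -= push
            cap[v][u] += push
        flow += push

    # Split edges whose _in side is still reachable from the source but whose
    # _out side is not form the min cut: those nodes get saved.
    reachable = set(parent)
    saved = {n for n in users
             if n + "_in" in reachable and n + "_out" not in reachable}
    return flow, saved


flow, saved = min_cut(USERS, REQUIRED, PRIMALS, weight)
print(flow, sorted(saved))  # 2 ['add_2']
```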

Some additional details:

  • I thought flow was on edges, not nodes?: There’s a standard transformation from “flow on edges” to “flow on nodes”. Basically, you give all of the existing edges infinite capacity. Then, you split every existing node into 2 nodes: one that takes all of the incoming edges and another that takes all of the outgoing edges. Finally, you add an edge between these 2 new nodes whose capacity equals the weight of the node.
  • What do you do if we don’t want to recompute some of our nodes? Perhaps they’re compute-intensive, or just aren’t fusible: This case is actually pretty easy to solve. Just add another edge from the source to the node you don’t want recomputed! If you do so, then from the perspective of the flow algorithm, that’s just another “input” node. So, if the backwards pass needs that value, the algorithm can either cut it (incurring only the bandwidth cost of that node), or save another value downstream of it that is sufficient. You can also imagine having multiple modes for this algorithm, one that tries to optimize runtime (i.e. only recomputes ops it can fuse), and another that tries to optimize memory (i.e. recomputes all ops except expensive ones).
  • What do you actually set as the weight for a node? There are basically 2 cases: 1. If the node would otherwise not be written out to global memory in the forwards pass, then saving it costs two transfers: a write in the forwards pass and a read in the backwards pass. 2. Otherwise, if the node already exists in global memory or must be written out to global memory anyway (such as the input to the forwards pass), then the only extra cost of the node is reading it in the backwards pass. So, in case 1 the weight is 2 * num_bytes, and in case 2 it’s num_bytes.
  • Why the asterisk on optimal? Well, what this pass guarantees is that the memory transfer between the forwards and backwards passes is minimized, subject to constraints on what is allowed to be recomputed in the backwards pass. We currently make two simplifying assumptions: 1. minimizing the memory transferred (while not recomputing any unfusible ops) is also optimal for runtime, and 2. memory usage within each graph can be ignored. Personally, I think these assumptions are fairly reasonable (and generally applicable), and my experiments so far seem to bear that out.
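Putting these three points together on the dropout-like example from earlier: rand_like is banned from recomputation by connecting it to the source, and the weights follow the two cases above with float32 = 4 bytes and bool = 1 byte. The graph and node names (x, rand, mask, sq, g_mask, g_x) are a hypothetical simplification (the factor of 2 in the gradient of x * x is dropped), and the solver is a plain Edmonds-Karp max-flow, not functorch’s implementation.

```python
from collections import deque

INF = float("inf")

# Hypothetical joint graph for f(x) = x * x * (rand_like(x) < 0.5), node -> consumers.
USERS = {
    "x": ["sq", "g_x"],         # primal input, float32
    "rand": ["mask"],           # rand_like(x), float32: must never be recomputed
    "mask": ["out", "g_mask"],  # rand < 0.5, bool
    "sq": ["out"],              # x * x
    "out": [],                  # sq * mask, the forwards output
    "tangents_1": ["g_mask"],
    "g_mask": ["g_x"],          # tangents_1 * mask
    "g_x": [],                  # g_mask * x (factor of 2 dropped)
}
REQUIRED = {"g_mask", "g_x"}     # the tangent's closure
BYTES = {"x": 4, "rand": 4, "mask": 1, "sq": 4, "out": 4}
IN_GLOBAL_MEMORY = {"x", "out"}  # written to global memory regardless
INPUTS, BANNED = {"x"}, {"rand"}


def weight(node):
    # Case 2: already in global memory, so only the backwards read costs anything.
    # Case 1: otherwise, pay for a forwards write plus a backwards read.
    return BYTES[node] if node in IN_GLOBAL_MEMORY else 2 * BYTES[node]


def min_cut(users, required, sources, weight_fn):
    """Return (cut cost, nodes to save) via node splitting + Edmonds-Karp."""
    cap = {}

    def add_edge(u, v, c):
        cap.setdefault(u, {})
        cap.setdefault(v, {})
        cap[u][v] = cap[u].get(v, 0) + c
        cap[v].setdefault(u, 0)  # residual back-edge

    for node, consumers in users.items():
        if node.startswith("tangents"):
            continue
        if node in required:
            add_edge(node + "_in", "sink", INF)
            continue
        if node in sources:  # graph inputs and banned nodes hang off the source
            add_edge("source", node + "_in", INF)
        add_edge(node + "_in", node + "_out", weight_fn(node))
        for user in consumers:
            add_edge(node + "_out", user + "_in", INF)

    flow = 0
    while True:
        parent, queue = {"source": None}, deque(["source"])
        while queue and "sink" not in parent:
            u = queue.popleft()
            for v, c in cap[u].items():
                if c > 0 and v not in parent:
                    parent[v] = u
                    queue.append(v)
        if "sink" not in parent:
            break
        path, v = [], "sink"
        while parent[v] is not None:
            path.append((parent[v], v))
            v = parent[v]
        push = min(cap[u][v] for u, v in path)
        for u, v in path:
            cap[u][v] -= push
            cap[v][u] += push
        flow += push

    reachable = set(parent)
    saved = {n for n in users
             if n + "_in" in reachable and n + "_out" not in reachable}
    return flow, saved


flow, saved = min_cut(USERS, REQUIRED, INPUTS | BANNED, weight)
print(flow, sorted(saved))  # 6 ['mask', 'x']
```

The solver saves x and mask (cost 6) rather than x and rand (cost 12). Dropping rand from the source set shows why the ban matters: rand has no inputs, so nothing forces it to be saved, and the solver returns just x, which would mean silently regenerating different random numbers in the backwards pass.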

There are, however, several somewhat pathological cases. For one, we need to avoid recomputing too much of the graph. For example, if we recomputed the entire forwards pass in the backwards pass, then the peak memory usage in our backwards pass wouldn’t change, even though we’d be passing very little memory between the forwards and backwards passes. In addition, although fusing ops together is mostly free, if we were to fuse, say… a hundred multiply ops (or 10 trigonometric ops), we could potentially find ourselves compute-bound again.

So far, these haven’t really been significant issues, as there are usually hard constraints (i.e. don’t recompute matmuls) that prevent us from recomputing inordinately massive graphs. However, I plan to add some kind of heuristic that upweights nodes that imply more recomputation (i.e. are further away from the backwards graph).
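One possible shape for such a heuristic, purely as a hypothetical sketch: measure each node’s distance (in edges) to the nodes the backwards pass requires, and scale its weight up with that distance. `ALPHA`, `distance_to_backward`, and `adjusted_weight` are invented names, and the linear scaling is an arbitrary choice; the graph is the reconstructed f(a, b, c, d) example.

```python
from collections import deque

# Hypothetical knob: how strongly distance from the backwards pass is penalized.
ALPHA = 0.1

# Reconstructed joint graph for f(a, b, c, d) (node -> consumers).
USERS = {
    "primals_1": ["add"], "primals_2": ["add"], "primals_3": ["add_1"],
    "primals_4": ["add_2"], "add": ["add_1"], "add_1": ["add_2"],
    "add_2": ["cos", "sin_1"], "cos": ["cos_1", "sin"], "cos_1": [],
    "sin": ["neg"], "neg": ["mul"], "sin_1": ["neg_1"], "neg_1": ["mul_1"],
    "tangents_1": ["mul"], "mul": ["mul_1"], "mul_1": [],
}
REQUIRED = {"mul", "mul_1"}


def distance_to_backward(users, required):
    """Fewest edges from each node to any required node (BFS over reversed edges)."""
    producers = {node: [] for node in users}
    for node, consumers in users.items():
        for consumer in consumers:
            producers[consumer].append(node)
    dist = {node: 0 for node in required}
    queue = deque(required)
    while queue:
        node = queue.popleft()
        for producer in producers[node]:
            if producer not in dist:
                dist[producer] = dist[node] + 1
                queue.append(producer)
    return dist  # nodes that never feed the backwards pass are absent


def adjusted_weight(base_weight, distance):
    """Saving a far-away node implies more recomputation, so make it costlier."""
    return base_weight * (1 + ALPHA * distance)


dist = distance_to_backward(USERS, REQUIRED)
print(dist["neg"], dist["add_2"], dist["primals_1"])  # 1 3 6
```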

TorchBench Performance Results

First, let’s look at results on TorchBench GPU training. Note that GPU training has always been PyTorch’s primary focus, so this is probably the most difficult setting in which to improve on TorchBench. Runtime is averaged over 20 runs, and memory is measured at the end of the forwards pass, just before the backwards pass starts.

To showcase the flexibility of this approach, we run 2 different configurations of our min-cut recomputation solver. In the first one, called “conservative”, we only allow operators that can be easily fused by NVFuser to be recomputed, such as pointwise operators and reductions. This ensures that we only recompute operators that are easy to fuse, and thus are “free”. In the second one, called “aggressive”, we instead allow all operators to be recomputed, except for compute-intensive operations and operations involving randomness.

As we can see from the results, AOT + MinCut-Conservative performs well across the board in both memory and runtime. It almost never underperforms eager-mode by more than an insignificant margin, and overall nearly triples the average performance improvement compared to Torchscript while also providing 3.5% average memory savings.

On the other hand, AOT + MinCut-Aggressive consistently slashes the memory cost (by 30% on average!), while not incurring a massive runtime cost (only 1.3% on average). Looking closer, though, many of these models do see a 5-10% reduction in performance. In certain cases, such as when memory-bound or doing distributed training, users will happily make this trade. In addition, memory savings can also often be leveraged into better utilization, since it allows larger batch sizes to be used.

Limitations on TorchBench

There are many models in TorchBench that I don’t have results for. For the most part, they either 1. didn’t have a training component, or 2. I couldn’t figure out how to train them easily (like the RL models…). However, there are a couple of models that AOTAutograd fails on. The two I ran into in particular were fastNLP_Bert and speech_transformer, which both convert tensors into scalar values at some point. I couldn’t get some other models like these to run either (like maskrcnn), and there are no doubt more. There were also a couple of models where the TS lowering failed, usually because empty lists were parsed as Tensor lists, which threw a type error.

EvoNorm Results

I’ll highlight a couple more performance results.

The first one, the EvoNorm module, highlights the performance improvement that comes from a smarter recomputation strategy. EvoNorm (evo_norm.py on the bits_and_tpu branch of rwightman/pytorch-image-models on GitHub) is a new-ish normalization layer that Ross Wightman has been looking at for some time. Ross had previously struggled to speed it up with PyTorch on GPUs:

With NVFuser’s improvements, as well as our min-cut pass, though, we can now speed it up significantly. We try it in 2 different size settings.

(128, 32, 128, 128)
eager: 6617.910861968994
AOT (recompute everything): 2145.9293365478516
AOT (flow-based): 1474
torchscript: 3630.642890930176

(128, 2048, 8, 8)
eager: 1932.8975677490234
AOT (recompute everything): 874.9675750732422 
AOT (flow-based): 545
torchscript: 1140.0103569030762

Note that Torchscript with NVFuser already speeds it up significantly (by nearly 2x!). However, even by applying a trivial recomputation strategy that recomputes the entire forwards pass, we can get 30-50% performance improvements. Going further and applying our min-cut pass speeds it up even more, and we end up with something nearly 4x faster than eager mode.

If you’re curious what “magic” the cover photo was referring to, now you know :slight_smile:

nn.TransformerEncoder Results

We’ve been looking at speeding up PyTorch’s nn.TransformerEncoder (along with Natalia) - specifically, the pointwise operators. Previously, we were looking at each operator individually. With this pass, though, we can simply apply the fusion to the entire module.

We can improve both the memory and the runtime by ~11% compared to eager, and the memory by 11% and runtime by 3% compared to Torchscript.

Also, if the users desire a more memory-efficient version, the MinCut-Aggressive pass results in 45% memory savings, while still being about as efficient as Torchscript.


AOTAutograd is still a new API, and we’ve only scratched the surface of the optimizations that AOTAutograd could enable. But… the results so far demonstrate that we can significantly improve both memory usage and runtime across a wide range of models, in an area that has been resistant to compilers yet matters most: GPU training in PyTorch.

Moreover, the ease with which we implemented this optimization pass makes us very optimistic about the future of AOTAutograd as an extensibility point. The entirety of this (arguably) advanced optimization pass consists of 80 lines of Python (functorch/aot_autograd.py on the main branch of pytorch/functorch on GitHub), and requires no knowledge of PyTorch internals beyond the extensibility point we provide. Furthermore, we’ve talked to several MLSys researchers who are very interested in this extension point and have expressed excitement about using it :slight_smile:

Finally, I’ll note that there are still plenty of easy performance wins left on the table here, several of which are close to being finished:

  1. We’re not fusing any in-place operators, due to an issue with Torchscript’s alias analysis pass and the way AOTAutograd represents its graph.
  2. In many cases, I needed to disable NVFuser’s batch norm fusion due to NVFuser errors. Enabling it is likely to expose more fusion opportunities.
  3. NVFuser currently does not fuse views, which rules out a lot of fusion opportunities in PyTorch code.

These are just the fixes/optimizations already in flight. There are plenty more optimization opportunities that we’re looking at, ranging from integrating CUDA Graphs to TASO-style graph optimizations and automatic offloading.


Thanks for the proposal, which looks exciting!

After reading the proposal, I have a few questions:

  1. In the first figure in Background, I suppose the read/write here means the communication between GPU global memory and shared memory / registers (please correct me if this is wrong). My general understanding of the runtime is that we have to write output tensors back to global memory after each individual op, so that they can be managed by the tensor manager and become the input of the next op. With this understanding, the old forward pass should have 3 memory reads and 3 memory writes; the new backward pass should have 7 memory reads and 4 memory writes.
    However, from this figure, the old forward pass only takes 1 memory read and 3 memory writes, so it seems like S2 and S1 could stay in the shared memory and directly be used by the next op. Does that mean the output tensor is not necessary to be written back to global memory? If so I’m wondering how this is achieved.

  2. In the TorchBench Performance Results, could you explain the meaning of numbers in the table? Taking the first row as an example, TS: Mem, Runtime = (93.92%, 117.67%) while AOTConservative: Mem, Runtime = (101.6%, 111.11%). Does that mean the AOT one reduces 6.56% memory with 7.68% performance overhead? Also what’s the baseline (i.e., 100%) of this?

  3. In the general recomputation strategy, you proposed that recomputable ops should be the bandwidth bound ops, which are mostly pointwise ops. This strategy makes sense to me, but AFAIK, there are also some compute-intensive pointwise ops, so a list of recomputable ops might be shorter than we expected. In this case, I’m wondering how large the optimization opportunity would be, but I guess we could just perform more evaluations to answer this question :slight_smile:


However, from this figure, the old forward pass only takes 1 memory read and 3 memory writes, so it seems like S2 and S1 could stay in the shared memory and directly be used by the next op.

This is in the context of fusing compilers. A fusing compiler (like Torchscript) can fuse these operators together and elide the extra reads from global memory. You can roughly think of the fused kernel’s flow as:

x1: SharedMemory = x.sigmoid()
x2: SharedMemory = x1.sigmoid()
x3: SharedMemory = x2.sigmoid()
s2: GlobalMemory = writeToGlobal(x1)
s1: GlobalMemory = writeToGlobal(x2)
out: GlobalMemory = writeToGlobal(x3)
return out, s2, s1

Does that mean the AOT one reduces 6.56% memory with 7.68% performance overhead? Also what’s the baseline (i.e., 100%) of this?

Right, for this case, compared to TS. I think this is the only model where AOTAutograd performs worse than Torchscript (should have hidden this one in the middle of the table :P). The values are all in comparison to eager, and are essentially eager_mem/AOT_mem and eager_runtime/AOT_runtime. So eager is 100%.
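For reference, here is how the percentages in the table are computed, next to the inverted form; the function names are invented for illustration. Under the table’s convention, values above 100% mean the compiled version beats eager.

```python
def relative_to_eager(eager, compiled):
    """Metric used in the table: eager / compiled as a percentage (higher is better)."""
    return 100.0 * eager / compiled


def compiled_over_eager(eager, compiled):
    """Inverted view: compiled / eager as a percentage (lower is better)."""
    return 100.0 * compiled / eager


# A compiled run taking 90% of eager's time shows up as ~111% in the table.
print(round(relative_to_eager(1.0, 0.9), 2), round(compiled_over_eager(1.0, 0.9), 2))
# 111.11 90.0
```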

This strategy makes sense to me, but AFAIK, there are also some compute-intensive pointwise ops, so a list of recomputable ops might be shorter than we expected.

Yeah, you can see the ops we recompute here: https://github.com/pytorch/functorch/blob/main/functorch/_src/aot_autograd.py#L205

In this case, I’m wondering how large the optimization opportunity would be

Well, I think these results give a pretty good sense. Geomean of 6% for GPU training is pretty good imo :stuck_out_tongue: Although it certainly depends on your model. If you have a simple MLP with matmuls followed by relus, there’s not going to be much of an opportunity to optimize it.


Thanks for the clarification!

It makes a lot of sense to me now if all figures in Background consider fusion. Accordingly, a design problem arises in this case: when modeling the read/write cost to build the graph for the min-cut algorithm, we assume that the op will be fused so that the write cost can be eliminated. However, this also means that the cost depends heavily on NVFuser’s behavior and has to align with it perfectly. For example, suppose we add a new feature to NVFuser so that it won’t create a fused op with more than 10 ops. Then we would have to change the cost computation logic accordingly.

For the result table, sorry, I didn’t notice that the title already mentions the improvements are over Eager… Also, I would suggest switching the equation to either AOT/Eager or 1 - Eager/AOT (for both, lower is better). IMHO, that might be more straightforward.

Finally, yeah I agree that the presented results deliver a good sense (after I understood how to read them correctly :stuck_out_tongue_winking_eye:).

However, it also means that this cost would highly depend on the NVFuser behavior and has to perfectly align to it.

It’s not too bad I think - experimentally it seems to work well :slight_smile:

Certainly there are limitations to this approach, and areas where our simplifications don’t align with reality. But for our use case, I think this approach strikes a nice balance between 1. optimality, 2. simplicity, and 3. speed, which is a good fit for practical use cases imo.

We’ve also seen substantial improvement from the use of min-cut and other recomputation heuristics on the GPU in Enzyme (https://dl.acm.org/doi/pdf/10.1145/3458817.3476165 see “Recompute versus Cache Heuristics” on page 7 for the description, and Figure 10 for the performance impact, or alternatively slides 31-33 https://c.wsmoses.com/presentations/enzyme-sc.pdf#31).

For the min-cut in particular, we took a slightly different approach, representing each instruction as two graph nodes (to allow for two types of edges, a value-use edge and an edge to represent the use, which makes it easy to handle values with multiple uses, etc.), along with various heuristics regarding loop depth.

See here:


Oh, very cool!

I think this sounds very similar to what we’re doing :sweat_smile: I hadn’t seen an approach similar to this in any prior literature, but it’s a pretty intuitive approach so I’m not too surprised.

Reading closer, I think it might actually be morally identical.

Any other good ideas on ways to optimize autodiff? :stuck_out_tongue:

This is a very cool feature! Formulating this problem as a min-cut is quite neat!

Quick question: Is there a way to configure recomputable_ops based on their compute / data bandwidth ratio? It looks to me that the current recomputable_ops list is hard-coded based on the types of operators, which could miss some optimization opportunities.

For example, let’s say that we are looking at A @ B with two configurations:
Config. 1: A.shape = (100, 100), B.shape = (100, 100), Dim_M = Dim_K = Dim_N = 100
Config. 2: A.shape = (100, 2), B.shape = (2, 2), Dim_M = 100, Dim_K = 2, Dim_N = 2

Configs 1 and 2 are both matmuls, so both would be treated as non-recomputable. However, in Config 2 the compute / memory footprint of A @ B looks more like that of an element-wise operator, in the sense that the work per row of Dim_M is very low (Dim_K and Dim_N << Dim_M). Therefore, Config 2 actually makes A @ B a good candidate for recomputation.

If we want to make Min-cut recomputation shape-aware, what in your view could be a good way to approach this?


@kelayamatoz I think it would be totally feasible to simply modify the “ban recomputation” algorithm to add a check for operators we consider compute-bound (like matmuls) and recompute them if they’re actually bandwidth bound.
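To make the matmul case concrete, one way to phrase such a check is by arithmetic intensity (FLOPs per byte moved). The sketch below assumes float32 operands and counts 2*M*K*N FLOPs against one read of A and B plus one write of C; `THRESHOLD` is an arbitrary placeholder, since the real crossover point depends on the GPU’s compute-to-bandwidth ratio, and the helper names are hypothetical.

```python
def matmul_intensity(m, k, n, bytes_per_elem=4):
    """FLOPs per byte of global-memory traffic for an (m x k) @ (k x n) matmul."""
    flops = 2 * m * k * n  # one multiply and one add per inner-product term
    traffic = bytes_per_elem * (m * k + k * n + m * n)  # read A and B, write C
    return flops / traffic


# Hypothetical cutoff: below this intensity, treat the matmul as bandwidth-bound
# and allow it to be recomputed. The right value depends on the hardware.
THRESHOLD = 4.0


def recomputable_matmul(m, k, n):
    return matmul_intensity(m, k, n) < THRESHOLD


print(round(matmul_intensity(100, 100, 100), 2), recomputable_matmul(100, 100, 100))  # 16.67 False
print(round(matmul_intensity(100, 2, 2), 2), recomputable_matmul(100, 2, 2))  # 0.5 True
```

Config 1 does roughly 17 FLOPs per byte moved, while Config 2 does about 0.5, which is in the same range as a pointwise op.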

We actually have a somewhat similar check in the other direction: we ban recomputing anything (primarily reductions) whose output is >4x smaller than its input. The idea here is that although reductions are always bandwidth-bound, a composition of, say, broadcasting + reduction (a class that includes matmuls) might not be!
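The size rule just described might look roughly like this; the function name and the use of element counts (rather than bytes) are assumptions made for illustration.

```python
import math


def ban_reduction(input_shape, output_shape, max_ratio=4):
    """Ban recomputation when the output has over `max_ratio`x fewer elements than the input."""
    return math.prod(output_shape) * max_ratio < math.prod(input_shape)


print(ban_reduction((128, 1024), (128,)))  # True: 1024x smaller, so keep it saved
print(ban_reduction((3, 8), (8,)))  # False: only 3x smaller, still fine to recompute
```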


Awesome! Thanks for the explanation! (And thanks so much for getting back to this post on a Friday night. :stuck_out_tongue: ) Let me research more on these directions.

I have two questions:

  1. Does torch.compile with the Triton backend support recomputation?
  2. How are random ops handled during recomputation? For example dropout, bernoulli, rand_like…