Compiling the optimizer with PT2


  • Compiling the optimizer improved end-to-end (E2E) performance on all benchmark suites: HuggingFace +18%, TorchBench +19%, and TIMM +8%
  • Compiled optimizers now support cudagraphs
  • Averaged over all models in each test suite, compile time increases by about 40s; in-progress optimizations could bring this below 30s

Motivation and Background

Optimizers are an integral part of training: every model uses an optimizer in some form to update its parameters according to the model’s gradients. Nvidia published results showing that fully fusing the optimizer with hand-tuned kernels yields an 8-10% end-to-end speedup on large transformer models. Generating a fused optimizer with a compiler should give similar performance benefits, with the added flexibility that the generated implementation automatically picks up any changes to the Python code, and that new optimizers without hand-tuned kernels get speedups too. In addition, the current eager optimizers frequently have to handle edge cases, e.g. grouping parameters by device and dtype in order to support horizontal fusion; these rules can instead be applied automatically in the compiler, improving the readability and simplicity of the codebase.

Design Challenges

Single Tensor vs Multi Tensor

Each optimizer has a single-tensor implementation, which loops over the parameters, computing and applying the update one parameter at a time. There is also a multi-tensor implementation, which horizontally fuses each operation across the iterations of that loop to improve performance, but which still reads and writes all parameters from memory for each distinct torch operation. The initial approach was to compile the single-tensor version into a fully fused kernel, but this proved difficult in the PT2 stack due to a number of issues:

  • Unrolling the for-loop takes a very long time in dynamo (> 5 min), and it traces the same ops over and over again
  • Searching that large traced graph for fusions also takes a long time in the scheduler, and if an op is accidentally fused across iterations too early, the result can be catastrophic for performance
  • Tracing the initialization of the optimizer state takes a long time; this is again a loop over all parameters in the model that sets up the optimizer state
  • Mutation introduces unbreakable dependencies between nodes (to ensure mutations occur in the correct order)
  • Eager relies on a lot of Python math on scalars extracted with getitem, which is fast in eager but difficult to trace and not cudagraphable

Due to the above issues, especially the long tracing and scheduling times, the approach taken instead was to compile the multi-tensor implementation. This had its own challenges, namely:

  • Initialization still takes a long time because it is shared by the single-tensor and multi-tensor implementations
  • Inductor at the time did not support efficient code generation for the foreach ops which comprise the multi-tensor implementation
  • Not all ops have foreach implementations
  • Mutation still introduces unbreakable dependencies between nodes
  • There are still a lot of getitem calls in loops for ops that don’t support foreach

Although multi-tensor has issues similar to single-tensor, the main benefit is that both tracing time and scheduling time are significantly reduced, because the foreach ops introduce structure into the graph. Rather than dynamo having to unroll the loop over gradients as in single-tensor, the ops in the optimizer operate on the lists of gradients directly, so there are fewer nodes in the graph. Compiling the foreach ops directly also tells inductor explicitly which ops can be horizontally fused, so the scheduler no longer needs to search for these fusions. The following sections detail how the remaining issues were resolved.
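To make the node-count argument concrete, here is a minimal pure-Python sketch (no PyTorch involved). `ToyGraph`, `trace_single_tensor`, and `trace_multi_tensor` are hypothetical stand-ins for dynamo tracing; they only count recorded ops and ignore the getitem nodes that real FX graphs still contain for list arguments.

```python
# Hypothetical toy "tracer": it only records ops, standing in for the FX graph
# that dynamo would build. No real PyTorch APIs are involved.


class ToyGraph:
    def __init__(self):
        self.nodes = []

    def record(self, op_name, operands):
        self.nodes.append((op_name, operands))


def trace_single_tensor(graph, num_params, ops_per_param):
    # Single-tensor: tracing unrolls the loop over parameters, so every op
    # is recorded once per parameter.
    for p in range(num_params):
        for op in range(ops_per_param):
            graph.record(f"op{op}", [f"param{p}"])


def trace_multi_tensor(graph, num_params, ops_per_param):
    # Multi-tensor: each foreach op takes whole lists, so it is recorded once
    # with all parameters as a single list operand.
    params = [f"param{p}" for p in range(num_params)]
    for op in range(ops_per_param):
        graph.record(f"foreach_op{op}", params)


single, multi = ToyGraph(), ToyGraph()
trace_single_tensor(single, num_params=1000, ops_per_param=10)
trace_multi_tensor(multi, num_params=1000, ops_per_param=10)
print(len(single.nodes), len(multi.nodes))  # 10000 10
```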

Improving Optimizer Initialization Time

The initialization time of the optimizers has consistently been a problem, with the main slowdown coming from dynamo tracing through a loop that sets up the optimizer state for each parameter in the model. This loop doesn’t contain any ops that belong in the traced graph; it usually sets up tensors that track the running statistics used to compute parameter updates. Initially, dynamo was disabled on this function to force it to run in eager, but in the long term full-model compilation should not have any graph breaks. The key insight for removing this graph break is that, in eager, this initialization only runs once, on the first call to optimizer.step(); on subsequent calls it is skipped because the optimizer state has already been set up. So when dynamo encounters a call to the initialization, it runs the function directly without tracing it, since the function only mutates the optimizer object to store the per-parameter state. This sets up the optimizer state, and dynamo then traces the step function normally. This matches the externally observable behavior of eager, where the initialization also happens only once, on the first call.
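The sketch below models that behavior in plain Python under simplifying assumptions. `ToyOptimizer`, `init_state`, and `toy_compiled_step` are hypothetical names, the "graph" is just a list of recorded ops, and the update is a momentum-style rule chosen for brevity rather than Adam. The point is that initialization runs directly and only creates state on the first call, while only the update ops are recorded.

```python
class ToyOptimizer:
    def __init__(self, params):
        self.params = list(params)
        self.state = {}        # filled lazily, like eager on the first step()
        self.init_events = 0   # number of state entries actually created

    def init_state(self):
        # Per-parameter setup loop: it only stores state on the optimizer
        # object and contains no math that belongs in the traced graph.
        for i in range(len(self.params)):
            if i not in self.state:
                self.state[i] = {"exp_avg": 0.0}
                self.init_events += 1


def toy_compiled_step(opt, grads, lr, graph):
    # 1) Run initialization directly, recording nothing; it only mutates `opt`
    #    and becomes a no-op once the state exists.
    opt.init_state()
    # 2) Record and apply the update ops (momentum-style update for brevity).
    for i, g in enumerate(grads):
        st = opt.state[i]
        st["exp_avg"] = 0.9 * st["exp_avg"] + 0.1 * g
        opt.params[i] -= lr * st["exp_avg"]
        graph.append(("update", i))


opt, graph = ToyOptimizer([1.0, 2.0]), []
toy_compiled_step(opt, [1.0, 1.0], 0.1, graph)
toy_compiled_step(opt, [1.0, 1.0], 0.1, graph)
print(opt.init_events, len(graph))  # 2 4
```

State is created once per parameter across both calls, and the recorded log contains only update ops.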

Compiling Foreach Ops

The main feature inductor was missing for multi-tensor optimizer compilation was efficient codegen for the foreach operators. The goal of this work is to generate a kernel that assigns fixed-size tiles from multiple input tensors to individual triton program ids (PIDs), each of which performs the computation on its tiles. Since these inputs are not contiguous chunks of memory, there needs to be a mapping from program ids to input tensors. For example, for input lists of tensors A and B and output list C in C = torch._foreach_add(A, B), the PID mapping could look like this:


Each PID is assigned a chunk of a tensor from A and the corresponding chunk from B, and computes the matching chunk of the corresponding tensor in C.
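One simple way to build such a mapping is sketched below in plain Python, with lists standing in for tensors. `build_pid_map` and `foreach_add` are hypothetical helpers, and the tile order (tensors in list order, tiles in offset order) is an assumption for illustration rather than inductor’s actual codegen.

```python
def build_pid_map(numels, block_size):
    """Return a list whose index is the program id and whose entries are
    (tensor index, start offset, tile length).

    Tiles of `block_size` elements are carved out of each tensor in list
    order, and each tile gets the next program id.
    """
    pid_map = []
    for tensor_idx, n in enumerate(numels):
        for start in range(0, n, block_size):
            pid_map.append((tensor_idx, start, min(block_size, n - start)))
    return pid_map


def foreach_add(A, B, block_size=4):
    # Lists of Python lists stand in for lists of tensors: C[t] = A[t] + B[t].
    C = [[0] * len(a) for a in A]
    # Each entry of the PID map is one "program": it reads one tile of A[t]
    # and B[t] and writes the same tile of C[t].
    for t, start, length in build_pid_map([len(a) for a in A], block_size):
        for j in range(start, start + length):
            C[t][j] = A[t][j] + B[t][j]
    return C


print(build_pid_map([5, 3, 8], block_size=4))
# [(0, 0, 4), (0, 4, 1), (1, 0, 3), (2, 0, 4), (2, 4, 4)]
```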

During lowering, the main complexity is that FX does not have a first-class notion of a list. The list input arguments of the foreach ops are unique, immutable list objects of getitem nodes, even if the upstream node was a foreach op. So even if two lists contain the same nodes, each is a separate Python object in the graph. With this in mind, the foreach op itself is represented as a list of pointwise ops, which is registered on the GraphLowering object for later use by the scheduler. With this representation, upstream and downstream pointwise ops can be fused vertically with the foreach node, provided that the normal rules for vertical fusion are met.

Within the scheduler, a pass coalesces all of the buffer lists registered during lowering into ForeachKernelSchedulerNodes: a subclass of FusedSchedulerNode that restricts fusion to vertical fusion only and internally stores a list of SchedulerNodes, one for the tensors at each index of the input lists. For a fusion to be legal, the writes performed by each internal SchedulerNode must match the reads of the consuming SchedulerNode at the same list index. Additionally, the normal vertical fusion rules must permit fusion at each index of the producer and consumer SchedulerNode lists. If these conditions are met, the two ForeachKernelSchedulerNodes are vertically fused into a single ForeachKernelSchedulerNode in which the pointwise ops at each list index have been fused. This is shown below.
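The legality check can be sketched as follows. `SubNode`, `can_fuse_foreach`, and `fuse_foreach` are hypothetical simplifications of the scheduler classes named above: reads and writes are sets of buffer names, `can_fuse_vertical` is a placeholder for the normal per-node vertical fusion rules, and "match" is read strictly as "the consumer at index i reads exactly the producer outputs from index i, and no producer outputs from other indices".

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SubNode:
    # One pointwise op inside a foreach node (stand-in for a SchedulerNode).
    reads: frozenset
    writes: frozenset


def can_fuse_foreach(producer, consumer, can_fuse_vertical=lambda p, c: True):
    if len(producer) != len(consumer):
        return False
    produced = frozenset().union(*(p.writes for p in producer))
    for p, c in zip(producer, consumer):
        # Index i of the consumer must read exactly what index i of the
        # producer writes, and nothing written at other indices.
        if (c.reads & produced) != p.writes:
            return False
        # Placeholder for the ordinary vertical fusion rules at this index.
        if not can_fuse_vertical(p, c):
            return False
    return True


def fuse_foreach(producer, consumer):
    # Fuse index by index: the result is still one fused op per list index.
    return [
        SubNode(reads=p.reads | (c.reads - p.writes), writes=p.writes | c.writes)
        for p, c in zip(producer, consumer)
    ]


def node(reads, writes):
    return SubNode(frozenset(reads), frozenset(writes))


producer = [node({"a0"}, {"t0"}), node({"a1"}, {"t1"})]
aligned = [node({"t0"}, {"c0"}), node({"t1"}, {"c1"})]
crossed = [node({"t1"}, {"c0"}), node({"t0"}, {"c1"})]
print(can_fuse_foreach(producer, aligned), can_fuse_foreach(producer, crossed))  # True False
```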

With this fusion implemented, a sequence of foreach ops can be fused into a single kernel, enabling full fusion of the multi-tensor optimizers.

Handling Mutation

To save memory, the optimizers often use in-place ops instead of their out-of-place variants. This is possible because, at face value, the optimizers only mutate the parameters as a function of the gradients, so no additional allocations should be needed. To handle mutations, inductor introduces special dependencies (StarDeps) on mutated buffers to ensure that all consumers of the un-mutated value run before the mutation occurs. An example of this is shown below.

Since these dependencies are never removed, they prevent fusion in certain cases, which inhibits fully fusing the optimizer’s kernels. The key observation is that these dependencies should be removable once only two nodes are left in the graph: if all of the readers of the pre-mutation value have already been fused with the node that produces that value, it is safe to fuse that node with the downstream mutation, since the ops will still occur in the correct order. This insight led to the implementation of WeakDeps: dependencies that can be removed when they are the only dependency left between two buffers. After each fusion pass, a pruning pass removes any WeakDeps that are preventing final fusions. The final result is shown below.
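A simplified sketch of the pruning idea in plain Python. `Dep`, `prune_weak_deps`, and the `group_of` mapping (buffer name to the fused node it now lives in) are hypothetical, and the rule encoded here is one reading of the description above: a WeakDep is dropped once its target has been fused into a node that the mutating node already has a real dependency on (or into the mutating node itself), because the remaining dependency then guarantees the same ordering.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dep:
    name: str            # node/buffer the dependency points at
    weak: bool = False   # True for a WeakDep (ordering-only)


def prune_weak_deps(deps, group_of, self_group):
    # Fused groups this node already depends on through real dependencies.
    strong_groups = {group_of[d.name] for d in deps if not d.weak}
    kept = []
    for d in deps:
        target = group_of[d.name]
        if d.weak and (target in strong_groups or target == self_group):
            continue  # ordering already implied, so the WeakDep is redundant
        kept.append(d)
    return kept


# M mutates buf0 in place. It reads buf0 (a real dependency on its producer P)
# and has a WeakDep on R, an earlier reader of the un-mutated buf0.
deps_of_m = [Dep("buf0"), Dep("R", weak=True)]

before = prune_weak_deps(deps_of_m, {"buf0": "P", "R": "R"}, self_group="M")
after = prune_weak_deps(deps_of_m, {"buf0": "P+R", "R": "P+R"}, self_group="M")
print(len(before), after)  # 2 [Dep(name='buf0', weak=False)]
```

Before R is fused with P, the WeakDep must stay so that R runs before the mutation; once R lives inside P+R, depending on P+R already enforces that order.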

Cudagraphs Compatibility

Since foreach kernels can be quite large, taking arguments for every tensor in their input lists, when there are many small tensors the CPU overhead of setting up the arguments to each generated kernel can become significant. Cudagraphs solve exactly this problem by recording kernel launches and replaying them on the GPU, without the CPU setup overhead, each time the graph is invoked. Implementing cudagraphs compatibility happened in two phases: first ensuring that each optimizer’s torch operations were cudagraphable, and then benchmarking the results to ensure that performance actually improved.

For the first step: initially, only Adam and AdamW had a cudagraphable implementation, which the user enabled manually by setting the capturable argument of the optimizer constructor to True. This flag swaps out fast scalar ops that are incompatible with cudagraphs for ops that may launch additional kernels but are correctly recorded by cudagraphs. The main example is getitem, which retrieves a single scalar from a tensor so that the value can be manipulated with ordinary Python math. This is in general much faster than launching a GPU kernel to perform the same operations on a scalar value. Capturable implementations were added for NAdam and ASGD; the remaining optimizers had no Python math in their eager implementations, so they did not need explicit capturable implementations.

Once this work was completed, enabling cudagraphs surprisingly resulted in a significant slowdown. Further investigation showed that this was because the optimizer states were not stored in cudagraph-owned memory. For correctness across replays, cudagraphs relies on tensor memory addresses being consistent across runs; inputs are usually copied into cudagraph-owned memory to maintain this invariant. If the inputs are generated by an upstream cudagraph, however, these copies are unnecessary, because those tensors already reside in cudagraph-owned memory. The optimizer states are inputs to the graph, but because they aren’t generated by an upstream cudagraph, every parameter state was copied on every iteration, which hurt performance catastrophically. To mitigate this, the mark_static_address API was implemented to allow a tensor to be marked as static across runs of the compiled graph. This allows the tensor to be lifted as an attribute on the graph instead of an input. Graph attributes are constant across runs, so the copies no longer occur. To apply this API automatically, dynamo marks the optimizer states as static addresses. To prevent leaking the optimizer states, weakrefs are used so that once the optimizer object goes out of scope, these extra attributes are cleaned up.
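The copy-versus-static behavior and the role of weak references can be modeled in a few lines of plain Python. `Buffer`, `ToyReplay`, `mark_static`, and `bind_input` are hypothetical stand-ins (object identity plays the role of a memory address); they illustrate the idea rather than the cudagraph or dynamo implementation. The final check relies on CPython freeing an object as soon as its last strong reference is deleted.

```python
import weakref


class Buffer:
    # Stand-in for a tensor: object identity plays the role of its address.
    def __init__(self, data):
        self.data = list(data)


class ToyReplay:
    def __init__(self):
        self.static = weakref.WeakSet()  # buffers marked static, held weakly
        self.owned = {}                  # input slot -> graph-owned copy
        self.copies = 0

    def mark_static(self, buf):
        self.static.add(buf)

    def bind_input(self, slot, buf):
        if buf in self.static:
            return buf  # read in place: its address is already fixed
        # Otherwise copy into graph-owned memory so every replay sees the
        # same address; this copy happens on every call.
        dst = self.owned.setdefault(slot, Buffer([0] * len(buf.data)))
        dst.data[:] = buf.data
        self.copies += 1
        return dst


replay = ToyReplay()
exp_avg, grad = Buffer([0.0] * 4), Buffer([1.0] * 4)
replay.mark_static(exp_avg)
for _ in range(3):  # three "iterations"
    replay.bind_input("exp_avg", exp_avg)
    replay.bind_input("grad", grad)
print(replay.copies)  # 3: only the non-static gradient was copied
del exp_avg
print(len(replay.static))  # 0: the weak reference did not keep the state alive
```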

Compilation Time

Compilation time has been a challenging issue when compiling the optimizer, mainly because models with a large number of parameter tensors produce very large optimizer graphs, which simply take a long time to compile with dynamo and inductor. To evaluate compile time improvements, compiling Adam with 1000 parameters was selected as the benchmark baseline, because Adam is the most popular optimizer and because 1000 is about the number of parameter tensors in the largest models in the OSS test suites; if the compile time of this benchmark can be reduced significantly, the average compile time should improve similarly.

At first measurement in July, the baseline took > 400s to compile. Profiling was then used to identify and remedy several bottlenecks.

Initially, the main bottleneck was determined to be inductor’s scheduling, so the initial focus was on reducing the number of nodes reaching the scheduler and on reducing the time taken to make fusion decisions. For fusion time, profiling the scheduler revealed that the epilogue copies that occur after functionalization were taking a long time to determine where they should be fused within a foreach op. To improve this, the data structure tracking dependencies for foreach ops was improved so that the decision to fuse a single op into a foreach op can be made in constant time. This resulted in a significant speedup, to ~289 seconds.
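The new data structure isn’t described here, so the sketch below is only an assumption of what constant-time lookup could look like: a hypothetical `ForeachSlotIndex` dict from each foreach output buffer to its (foreach node, list index) slot, compared with a linear scan over all outputs. A `copy_` epilogue that reads a given buffer can then be placed at the matching index directly.

```python
class ForeachSlotIndex:
    def __init__(self):
        self._slot_of = {}  # buffer name -> (foreach node id, list index)

    def register(self, foreach_id, output_names):
        for idx, name in enumerate(output_names):
            self._slot_of[name] = (foreach_id, idx)

    def find_slot(self, buffer_name):
        # Average O(1) dict lookup, independent of the number of outputs.
        return self._slot_of.get(buffer_name)


def find_slot_linear(foreach_nodes, buffer_name):
    # For comparison: scan every output of every foreach node.
    for foreach_id, outputs in foreach_nodes.items():
        for idx, name in enumerate(outputs):
            if name == buffer_name:
                return (foreach_id, idx)
    return None


nodes = {"foreach_addcdiv": [f"buf{i}" for i in range(1000)]}
index = ForeachSlotIndex()
for fid, outs in nodes.items():
    index.register(fid, outs)

# An epilogue copy that reads buf999 belongs at index 999 of the foreach node.
print(index.find_slot("buf999"))  # ('foreach_addcdiv', 999)
```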

Another bottleneck identified was the depth-first search performed in the pattern matcher for batch fusion. Upon further analysis, excluding certain types of nodes that are common in the optimizer from this search yielded an additional ~100s compile time reduction, without reducing the compiled model’s performance.
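The effect of excluding node types from a depth-first search can be illustrated with a toy graph. The `dfs` helper is hypothetical, and treating `getitem` as the excluded type is purely an illustrative choice; the source doesn’t say which node types were excluded.

```python
def dfs(graph, start, skip_ops=frozenset()):
    """Iterative DFS; `graph` maps node -> (op_type, successors).

    Successors whose op type is in `skip_ops` are neither visited nor expanded.
    """
    seen, stack, order = {start}, [start], []
    while stack:
        node = stack.pop()
        order.append(node)
        for succ in graph[node][1]:
            if succ in seen or graph[succ][0] in skip_ops:
                continue
            seen.add(succ)
            stack.append(succ)
    return order


# One foreach node fanning out to 500 getitem nodes, each feeding an add.
graph = {"root": ("foreach_mul", [f"g{i}" for i in range(500)])}
for i in range(500):
    graph[f"g{i}"] = ("getitem", [f"u{i}"])
    graph[f"u{i}"] = ("add", [])

print(len(dfs(graph, "root")), len(dfs(graph, "root", {"getitem"})))  # 1001 1
```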

The current time to compile 1k-parameter Adam stands at about 178s, with an average compile time increase of ~40s on each benchmark suite.

Below are three in-progress optimizations to further improve Adam compilation time.

Future strategies:

  1. FakeTensorPropagation caching (in progress) (~25s)
  2. Fuse copy_ nodes during lowering to improve scheduling time (in progress) (~36s)
  3. Remove replace_all in dynamo to improve tracing time (10-15s estimated)

With prototypes of (1) and (2), the compilation time of an Adam optimizer with 1k parameters is reduced to 117s, down from over 400s in July. Roughly, ~28s is spent in tracing, ~19s in AOT, and ~70s in inductor. (3) is under way and could reduce compilation time by another 10-15s by improving tracing time. The next step is to improve the prototypes of these optimizations so that the full average compile time of the test suites can be measured.
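As a quick consistency check of these numbers: the estimated savings from (1) and (2) account exactly for the drop from 178s to 117s, the per-stage breakdown sums to the same total, and if (3) delivers its estimated 10-15s, the total would land at roughly 102-107s:

```latex
\begin{aligned}
\underbrace{178\,\mathrm{s}}_{\text{current}}
  - \underbrace{25\,\mathrm{s}}_{(1)}
  - \underbrace{36\,\mathrm{s}}_{(2)} &= 117\,\mathrm{s} \\
\underbrace{28\,\mathrm{s}}_{\text{tracing}}
  + \underbrace{19\,\mathrm{s}}_{\text{AOT}}
  + \underbrace{70\,\mathrm{s}}_{\text{inductor}} &= 117\,\mathrm{s} \\
117\,\mathrm{s} - \underbrace{(10\ \text{to}\ 15)\,\mathrm{s}}_{(3),\ \text{estimated}} &\approx 102\ \text{to}\ 107\,\mathrm{s}
\end{aligned}
```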

End-to-End Results

Many OSS models see significant speedups from compiling the optimizer. To isolate the improvement from compiling only the optimizer, the testing methodology compares two speedups, both measured against the eager model with the eager optimizer: 1) the speedup of the compiled model with the eager foreach Adam optimizer, and 2) the speedup of the compiled model with the compiled optimizer. The highlights are shown below (full data).
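Assuming the reported percentages are the ratio of the two speedups minus one (the exact formula isn’t spelled out here), the gain attributable to compiling the optimizer can be computed as below. `optimizer_compile_gain` is a hypothetical helper, and the timings in the example are made up, not measured data.

```python
def optimizer_compile_gain(t_eager, t_compiled_model_eager_opt, t_compiled_both):
    # Both speedups are measured against the fully eager run (model + optimizer).
    speedup_eager_opt = t_eager / t_compiled_model_eager_opt   # speedup 1)
    speedup_compiled_opt = t_eager / t_compiled_both           # speedup 2)
    # Assumed definition of the reported "+X%": ratio of the two speedups.
    return speedup_compiled_opt / speedup_eager_opt - 1.0


# Hypothetical iteration times in ms, not measured data.
print(optimizer_compile_gain(100.0, 50.0, 40.0))  # 0.25, i.e. +25%
```

Under this definition the eager time cancels out: the gain is just the ratio of the two compiled-model iteration times minus one.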





Compiling the optimizer resulted in large performance gains across all of the OSS test suites: HuggingFace +18%, TorchBench +19%, and TIMM +8% E2E. In some rare cases the compiled optimizer was slower than its eager counterpart by a few percent; I think these cases could be due to inadequate tile size selection for the foreach ops, and to not grouping tensors of similar size together in the generated foreach kernels. Both of these optimizations could improve performance even further.

Running these experiments on our benchmarks also revealed some pre-existing numerical instability in the OSS test suites: when the non-compiled Adam optimizer is added to our benchmarks to compute the baseline speedup, about ⅓ of models fail the accuracy check. As expected, compiling the optimizer reproduces this behavior exactly. Although these accuracy issues aren’t related to the compiled optimizer, the common underlying cause should be addressed; initial analysis suggests possible anomalies in the generated data that result in NaNs or division by zero in both eager and compiled Adam.

What’s next

  • There are currently issues with the optimizer recompiling continuously when used with the LR Scheduler; fixing this will require wrapping the LR in a tensor that is propagated to all devices
  • Adding autotuning for foreach ops
  • Integrating full model benchmarks into the torchinductor dashboard with a variety of optimizers
  • Improve average compile time across test suites to < 30s (vs 40s today)
  • Support dynamic shapes with foreach compilation
  • What does it take for new experimental optimizers to benefit from this work? Do they need to have both single-tensor and foreach versions? (e.g. the benefit of old, eager, single-tensor versions was that they were copy-pasteable and hackable)
  • Is it somehow possible to auto-gen foreach versions? (e.g. write the foreach versions in python and have a “specialized” compiler to handle these?)
  • I also wondered whether it’s possible to introduce a TensorList type and somehow let torch.* methods dispatch to single-tensor or TensorList versions depending on the input; then more code could look identical for the single-tensor and foreach versions

Thanks :slight_smile:


These are great questions.

  • For experimental optimizers, as long as your optimizer consists of a straight-line sequence of foreach ops, torch.compile will generate a fully fused kernel. If this doesn’t work for some reason, or some ops are missing, please file an issue and tag me. You don’t strictly need the other optimizer modes, like the single-tensor version, or the grouping code and error checking; all of that is handled by the compiler if you only want to support PT2. If you want to support eager mode as well, you’ll have to add the extra checking that’s done in the other optimizers.

  • This is certainly possible, and there is some initial exploration around adding control flow primitives to explicitly allow a user to have dynamo represent these ops in the FX graph. In the meantime it is also possible to add a decorator where dynamo annotates these python foreach implementations when tracing, and then inductor can apply the same logic it uses on a canonical foreach op to fuse there too. Let me know if this is something you really need or if it is more of a nice-to-have.

  • We talked about this idea on slack a little bit and there’s also this TensorDict RFC. I’m pretty intrigued by this idea, and think it would be cool to implement the idea of a “pytree-able” container that you just do a bunch of pointwise ops on. Adding support for this in PT2 shouldn’t be too difficult, but adding 1000s of nodes to the graph from these containers could make compilation pretty slow; it already does for lists. Right now it isn’t a priority for me, just because lists are getting us the perf in PT2 right now, and this is more of a smaller UX improvement, but if you’re interested in driving it I’m happy to discuss more and provide support. cc @janeyx99 this has come up in a few places recently

Let me know if you have other questions or want to discuss these ideas in more detail


Hi @mlazos, could you please elaborate on how to wrap the LR in a tensor so that we can work around the recompilation problem? I ran into the same recompilation problem when compiling the optimizer step function.

Hi @gouchangjiang, can you send me your code? Wrapping the LR is needed when composing with the LRScheduler; otherwise it generally shouldn’t be needed. Can you paste a log after running your code with TORCH_LOGS="recompiles"? Please make this as minimal as possible.

With TORCH_LOGS="recompiles", it complains:

  • torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function adamw in /opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/optim/, triggered by the following guard failure(s): - L['lr'] == 0.0002999999984974876 # func( # optim/ in adamw

  • torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function adamw in /opt/conda/envs/python3.7/lib/python3.8/site-packages/torch/optim/, triggered by the following guard failure(s): - L['lr'] == 0.0002999999984964864 # func( # optim/ in adamw

OK, this is what I expected. If you pass torch.tensor(lr) to the optimizer constructor, that should solve the problem for now. Are you using an LR Scheduler, btw?

Yeah, you are right. Passing torch.tensor(lr) to the optimizer constructor works. I am using CosineAnnealingLRWithWarmup(LambdaLR).

Okay, as long as you’re only compiling the step, you should be fine. Compiling the LRScheduler itself might run into a similar recompilation issue, because it also uses Python scalars. I’m working on getting the LRSchedulers to compile by swapping these scalars out for tensors.
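The reason a Python float LR triggers recompiles like the ones in the log above, while a tensor LR does not, can be sketched with a toy guard cache. `ToyGuardCache` and `LRBox` are hypothetical. In this model, the float case bakes its exact value into the guard, like the `L['lr'] == ...` guards in the log, and the boxed value (standing in for `torch.tensor(lr)`) is guarded only on its type. Real tensor guards check metadata such as shape, dtype, and device rather than the value. In real code, the workaround from this thread is simply passing `torch.tensor(lr)` to the optimizer constructor.

```python
class ToyGuardCache:
    """Hypothetical model of value guards; not dynamo's actual implementation."""

    def __init__(self):
        self.entries = set()
        self.compiles = 0

    def call(self, lr):
        if isinstance(lr, float):
            key = ("float", lr)       # value baked in, like L['lr'] == 0.0003
        else:
            key = ("box", type(lr))   # only the type is checked, not the value
        if key not in self.entries:   # guard miss -> "recompile"
            self.entries.add(key)
            self.compiles += 1


class LRBox:
    # Stand-in for torch.tensor(lr): the value lives inside a mutable object.
    def __init__(self, value):
        self.value = value


schedule = [3e-4, 2.9e-4, 2.8e-4, 2.7e-4]

scalar_cache = ToyGuardCache()
for lr in schedule:
    scalar_cache.call(lr)

box_cache, box = ToyGuardCache(), LRBox(schedule[0])
for lr in schedule:
    box.value = lr  # the scheduler updates the value in place
    box_cache.call(box)

print(scalar_cache.compiles, box_cache.compiles)  # 4 1
```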