Torch.optim happenings: More Ws (2023 H2)

In the past half, torch.optim has literally added more Ws (with new configs NAdamW thanks to janEbert and RAdamW thanks to blizard) and has figuratively earned dubs with:

  1. reducing memory usage through minimizing intermediates, a new option for optimizer in backward, and shallow copying in load_state_dict,
  2. increasing speed with lerp, removing h2d syncs, and _foreach_copy (thanks @crcrpar),
  3. flexibility with optimizer state_dict hooks, lr scalar tensor support, and float64 scalar support,
  4. and robustness through resolving longstanding bugs (like foreach handling empty tensors), fixing complex handling for all optimizers, and improving testing with OptimizerInfos.

In this update, I will also summarize our progress on torch.compile() integration and what we envision as next steps.

reducing memory usage

For the 2.0 release, we defaulted to the more performant foreach implementation over the forloop (check out this explanation in the docs under Algorithms). One result of this horizontal fusion is that the computation intermediates become significantly larger, causing peak memory to go up (see the explanation in the issue). This is an expected tradeoff, since we need a minimum of 1 intermediate to avoid inplacing into buffers, but many of our optimizers were not as conservative with memory as they could have been. We took a survey of all foreach optimizers (Adam, AdamW, RAdam, NAdam, Adadelta, Adagrad, Adamax, Rprop, ASGD, RMSprop, SGD) and minimized their use of intermediates to decrease peak memory.
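If peak memory is a concern, the implementation can also be selected explicitly per optimizer. Here is a minimal sketch (shapes and hyperparameters are arbitrary placeholders) of forcing the memory-frugal forloop path via the foreach flag:

import torch

params = [torch.randn(64, 64, requires_grad=True) for _ in range(4)]
for p in params:
    p.grad = torch.randn_like(p)

# foreach=None (the default) picks the horizontally fused implementation when supported;
# foreach=False forces the forloop implementation, trading speed for lower peak memory
opt = torch.optim.Adam(params, lr=1e-3, foreach=False)
opt.step()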

If even the memory efficient forloop implementation causes OOMs during training, there is perhaps a way to further decrease memory usage for your training loop by moving the optimizer update into the backward pass! The trick is to avoid piling up a buffer of gradients that sits around until the backward pass completes before being fed into the optimizer step. Instead, run the optimizer update for each parameter as soon as its gradient is ready, so the gradient can be freed as early as possible.

For visualization, compare the memory snapshots before and after moving the optimizer update into the backward pass (note the lower maximum on the Y axis in the after snapshot).

We’ve added a Tensor.register_post_accumulate_grad_hook API to allow for this technique; see the tutorial How to save memory by fusing the optimizer step into the backward pass for whether this could help you.
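As a rough sketch of the pattern from that tutorial (the model, optimizer choice, and hyperparameters below are placeholders):

import torch
from torch import nn

model = nn.Linear(512, 512)

# one small optimizer per parameter so each update can run as soon as that grad is ready
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param) -> None:
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()  # set_to_none=True by default, freeing the grad

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# the usual training step, minus the explicit optimizer.step()/zero_grad() calls
model(torch.randn(8, 512)).sum().backward()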

Lastly, we rectified a past mistake where Optimizer.load_state_dict used to deepcopy the incoming state, thus requiring 2x state memory. Now, similar to nn.Module.load_state_dict, we make only a shallow copy to decrease memory usage by about half.

increasing speed

People who want speedy optimizers and use our fused optimizers (e.g., AdamW(fused=True)) along with CUDAGraphs can now do so with no performance hit by passing in a scalar tensor for lr (instead of a Python number). Previously, a tensor lr would cause a host to device synchronization (an .item() call) every time the lr changed, which unnecessarily held back performance. With kernel updates to accept tensor LRs in our fused optimizers, CPU time per 1k steps decreases by ~70ms on my local machine. @awgu also noticed similar unnecessary copy()’s for the optimizer step += 1 update in the common case where step is hosted on the CPU. We’ve since bypassed these copies, decreasing CPU time by ~50ms per 1k steps.
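For instance, here is a minimal sketch (assuming a CUDA device; shapes and lr values are placeholders) of passing a tensor lr to a fused optimizer so the lr can later be updated in place without a host sync:

import torch

model = torch.nn.Linear(128, 128, device="cuda")

# a scalar tensor lr avoids the .item() host<->device sync whenever the lr changes
lr = torch.tensor(1e-3, device="cuda")
opt = torch.optim.AdamW(model.parameters(), lr=lr, fused=True)

model(torch.randn(32, 128, device="cuda")).sum().backward()
opt.step()
lr.fill_(5e-4)  # e.g., a scheduler can update the lr in place, staying on device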

The simplest optimization we’ve done this half is just to use lerp whenever possible and get fusion of an add and mul for free.
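Concretely, here is the identity at play, using a stand-alone tensor in place of an optimizer’s moment buffer:

import torch

beta1 = 0.9
exp_avg, grad = torch.randn(1024), torch.randn(1024)

# the usual EMA update spelled out as a mul and an add:
via_mul_add = exp_avg.clone().mul_(beta1).add_(grad, alpha=1 - beta1)

# the same update as a single lerp: exp_avg + (1 - beta1) * (grad - exp_avg)
via_lerp = exp_avg.clone().lerp_(grad, 1 - beta1)

torch.testing.assert_close(via_mul_add, via_lerp)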

flexibility

Following an internal-only discussion with Distributed Checkpointing, we’ve enabled pre and post hooks on Optimizer.state_dict and Optimizer.load_state_dict in order to give users flexibility. With optimizer state_dict hooks, you can pre- and post-process the state_dict for saving and loading.
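For example, here is a minimal sketch of the kind of thing the hooks enable (the "_meta" key and the tiny model are made up for illustration):

import torch

model = torch.nn.Linear(16, 16)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# post-process the dict that state_dict() returns, e.g., stamp in extra metadata
def state_dict_post_hook(optimizer, state_dict):
    state_dict["_meta"] = {"torch_version": torch.__version__}
    return state_dict

# pre-process an incoming state_dict before load_state_dict() consumes it
def load_state_dict_pre_hook(optimizer, state_dict):
    state_dict.pop("_meta", None)
    return state_dict

opt.register_state_dict_post_hook(state_dict_post_hook)
opt.register_load_state_dict_pre_hook(load_state_dict_pre_hook)

sd = opt.state_dict()    # includes "_meta"
opt.load_state_dict(sd)  # "_meta" is stripped before loading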

If you’re not such a fan of tinkering with the state_dict, but you DO want to be able to save your fused optimizer state on CPU to load back into GPU, we’ve also fixed a bug with step casting. Hitherto, loading a CPU state_dict into a fused optimizer would fail with RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu! (when checking argument for argument state_steps), but no longer! This should just work.
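A rough sketch of that round trip (assuming a CUDA device; the tiny model and hyperparameters are placeholders):

import io
import torch

model = torch.nn.Linear(32, 32, device="cuda")
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)
model(torch.randn(4, 32, device="cuda")).sum().backward()
opt.step()

# save, then reload everything onto CPU, as a checkpoint written elsewhere might be
buf = io.BytesIO()
torch.save(opt.state_dict(), buf)
buf.seek(0)
cpu_state = torch.load(buf, map_location="cpu")

# loading the CPU state (including the CPU step tensors) into the fused GPU optimizer
# used to raise the state_steps device mismatch above; it should now just work
opt.load_state_dict(cpu_state)
opt.step()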

We’ve also widened support for optimizer scalar hyperparameters, allowing step to be float64 as well as float32 for use cases requiring more precision.

robustness

We once again met an old bug that we thought was fixed in H1. In essence, our multi_tensor_apply() function, which powers all our foreach operators by launching CUDA kernels on chunks of tensor lists, mishandled empty tensors in the input tensorlists. Once it met an empty tensor under certain patterns, it would halt and stop launching kernels on the rest of the tensors. The code was already complicated after several patches, including a prior attempt at fixing this exact bug, and we finally fixed the bug for real for real by refactoring the kernel entirely. Huge thanks to @ngimel for the quality review and r-barnes for polishing the style.
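To illustrate the shape of the problem (assuming a CUDA device): the foreach ops take lists of tensors, and an empty tensor anywhere in the list used to be able to cut the kernel launches short.

import torch

xs = [torch.randn(5, device="cuda"), torch.empty(0, device="cuda"), torch.randn(3, device="cuda")]
ys = [torch.randn(5, device="cuda"), torch.empty(0, device="cuda"), torch.randn(3, device="cuda")]

# with the refactored multi_tensor_apply, every nonempty tensor in the list gets its result
out = torch._foreach_add(xs, ys)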

We’ve identified gaps in support for complex types and filled them in by enabling and adding tests for complex on Adam, AdamW, NAdam, RAdam, and foreach Adadelta (the last three courtesy of @jon-chuang). In the same vein, we’ve patched in testing for missing coverage on sparse, capturable, and maximize configs.
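As a quick illustration of the complex support (a single free-standing complex parameter; hyperparameters are placeholders):

import torch

# a complex parameter is optimized as the corresponding pair of real values
p = torch.randn(8, dtype=torch.complex64, requires_grad=True)
p.grad = torch.randn_like(p)
torch.optim.Adam([p], lr=1e-3).step()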

Upon noticing that it was too easy to accidentally skip testing important optimizer configs and that there was no uniform way to test all configs at once, we introduced OptimizerInfos in common_optimizers.py, in which we’ve consolidated our optimizer configs. Modeled after ModuleInfos, this new infrastructure auto-enables testing for devices like MPS and CUDA and makes for simpler testing of features across optimizers. We’re in the midst of porting our current tests over, and, in the process of re-enabling dynamo for the migrated tests, we’ve spotted several PT2 optim bugs (tracked with skipIfTorchDynamo in common_optimizers.py). We’ve also hopped on the train with the new infra to generate more exhaustive compiled optimizer tests, immediately improving our compiled optimizer coverage.

torch.compile() integration

Following our collaboration with the compiler team (thanks @mlazos), you can now try out torch.compile() for the optimizer.step() on a majority of our optimizers (Adadelta, Adagrad, Adam, AdamW, ASGD, NAdam, RMSprop, Rprop, and SGD). To support these, we now offer a capturable API for NAdam in addition to Adam(W). See details and benchmarks in his post: Compiling the optimizer with PT2. The eventual hope is for torch.compile() to support general foreach ops to achieve automatic vertical fusion without the need of handwritten kernels, so feel free to try it on your own foreach optimizer and let us know how it goes!
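If you want to try it, a minimal sketch (assuming a CUDA device; the model and hyperparameters are placeholders) is to wrap just the step in a compiled function:

import torch

model = torch.nn.Linear(1024, 1024, device="cuda")
opt = torch.optim.NAdam(model.parameters(), lr=1e-3)

# compile only the optimizer step; Inductor can vertically fuse the foreach ops into fewer kernels
@torch.compile()
def step():
    opt.step()

model(torch.randn(32, 1024, device="cuda")).sum().backward()
step()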

Additionally, big thanks to @jon-chuang for landing a series of heuristical changes chopping off compilation time for common use cases. There is still much to do for hacking down compilation time, smoothing out integration with LRScheduler, and increasing coverage, so please holler by filing issues!

next steps

There are many things we could still do to improve optimizers, and the following are some themes we’ve settled on for H1 2024.

We are committed to solidifying the user experience, so we will continue smoothing out the experience with torch.compile(), striving for composability with distributed (FSDP, DTensor), and addressing user reported issues. Specifically, this would involve looking into:

  • LRScheduler integration with compiled optimizers,
  • tensor subclass support,
  • and fully migrating our tests to OptimizerInfos (let me know if you’d like to help).

As finetuning becomes more ubiquitous, we’d like to empower more users by providing memory-efficient techniques for constrained setups. We’re exploring ideas such as

  • adding memory-efficient optimizers like Adafactor or lower-bit AdamW,
  • supporting mixed precision in optimizer state,
  • and making optimizer step in backward more of a full-fledged product.

As an aside, we had also discussed allowing FQNs as keys in the optimizer state_dict (see [RFC] Introducing FQNs/clarity 👓 to optim state_dict), which we concluded falls under the umbrella of a better engineering project.

thanks

I’d like to call out some solid BE (better engineering) work, which helps the project stay hygienic:

Thanks to all who’ve contributed to these Ws! And to you for reading!


Has anyone from the team tried this technique with FSDP where the optim state is sharded? Does this fusion approach work in this scenario?

@raghukiran1224 Please correct me if I misunderstood, but I’m guessing you are asking about torch.compile() composing with FSDP for optimizers?

In theory, the fusion should work independently of whether the inputs are full vs sharded for pointwise optimizers (all our foreach optimizers). Inductor should be able to fuse a series of foreach operations on the same memory without trouble. I am less sure about fusing with the ops before and after the optimizer step, as torch.compile() support on FSDP distributed collectives is underway (see an update here: Torch.compile() + FSDP - Dec 8th) and @awgu would know more about the ops directly before and after the optimizer portions in FSDP.

Practically speaking, since distributed collectives are not compilable today, there’s no full-graph capture of FSDP. In other words, only the code regions between the collectives (like allgather, reduce_scatter) would be compiled. A similar chopped-up approach would have to take place for the optimizer: one would manually replace the normal optimizer.step() call with a torch.compile()'d version. Roughly, I’m imagining something like:

model = ...
fsdp_mod = FSDP(model)  # torch.distributed.fsdp.FullyShardedDataParallel
optim = AdamW(fsdp_mod.parameters())

# compile only the optimizer step; the collectives around it stay eager
@torch.compile()
def compiled_step():
    optim.step()

That said, I suspect no one on the team has tried this yet, but we are actively developing in this area, so if you do get a chance to try it, please let us know how it goes! Also cc’ing @wconstab, @mlazos, @voz, @awgu who would know more.