torch.compile() + FSDP - Dec 8th

Most of our updates on compiling FSDP have been internal. This is the first public post in this series.


What started as a project around May-June is slowly reaching a point where the remaining changes are well scoped.

When we started, almost nothing in torch distributed compiled at all. Entire categories of patterns that distributed code relies on (tensor hooks, backward hooks, process groups, collectives, etc.) were poorly supported in torch.compile.

The process has been a little bulldozer-y. We didn't do a particularly great job of scoping everything we would need to cover to get distributed working end-to-end (e2e) with torch.compile, but we had a few major points in mind:

  1. Compiling FSDP was critical to this endeavor; we cannot declare success without it

  2. Compile would have to be e2e. That is, compiling just the dense model/tensor compute portions of FSDP-wrapped models would not be enough; we would need to make sure collectives make it into our compiled graph too

  3. The work, as usual, splits into two halves: (i) soundly capturing a graph through both dynamo and aot_autograd, and (ii) soundly lowering it through inductor

  4. torch.compile w/ FSDP is full-graph only; we do not plan to support graph breaks with FSDP at this time

  5. Per-parameter FSDP is being worked on in parallel and has many overlapping requirements. The requirements specific to per-parameter FSDP are tracked in Tracing per-param sharding FSDP · Issue #114286 · pytorch/pytorch · GitHub
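Points 3 and 4 can be illustrated on a toy function (not FSDP). The sketch below assumes a PyTorch 2.x build where `torch.compile` accepts a custom backend callable and a `fullgraph=True` flag. The hypothetical `recording_backend` only keeps the FX graph that dynamo captured and runs it as-is, so it exercises the capture half without any inductor lowering. `torch._dynamo.graph_break()` forces a break, which `fullgraph=True` turns into an error instead of silently splitting the graph.

```python
import torch
import torch._dynamo

# Hypothetical backend for illustration: records each captured FX graph
# and executes it unchanged (capture only, no lowering).
captured_graphs = []

def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    captured_graphs.append(gm)
    return gm.forward

def clean_step(x, w):
    # No graph breaks: dynamo can capture this as a single graph.
    return torch.relu(x @ w).sum()

def step_with_break(x, w):
    y = torch.relu(x @ w)
    torch._dynamo.graph_break()  # explicitly forces a graph break here
    return y.sum()

x, w = torch.randn(4, 8), torch.randn(8, 3)

# Full-graph capture succeeds and yields exactly one graph.
compiled_clean = torch.compile(clean_step, backend=recording_backend, fullgraph=True)
out = compiled_clean(x, w)
num_graphs_clean = len(captured_graphs)

# With fullgraph=True, a graph break raises rather than splitting the graph.
try:
    torch.compile(step_with_break, backend=recording_backend, fullgraph=True)(x, w)
    raised_on_break = False
except Exception:
    raised_on_break = True
```

Dropping `fullgraph=True` would let dynamo split `step_with_break` into two graphs with eager code in between; the full-graph-only stance for FSDP rules that mode out.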

At this point, we not only have a PR stack tracking the remaining changes ([Do not review] Top of FSDP stack - to be broken up by voznesenskym · Pull Request #115410 · pytorch/pytorch · GitHub, which will be broken up into many smaller PRs), but we have also achieved a ton of intermediary steps along the way. These intermediary steps are valuable on their own and have significantly increased the coverage of what torch.compile can handle.

In no particular order, on our way to supporting FSDP, and aside from a lot of bug fixes, we've taught torch.compile how to handle:

  • Functional collectives
  • Process groups
  • DTensor
  • Forward hooks on both input and intermediary tensors and modules
  • Backward hooks on both input and intermediary tensors
  • Compiled autograd (turning the autograd graph into an fx.Graph and then passing that on to PT2)
  • Post accumulate grad hooks
  • Metadata-mutating ops, especially w.r.t. correctly resetting fake tensors for aot_autograd
  • Attribute mutations on tensors and other objects
  • Support for streams (graph capture only)
  • Support for functools.partial
  • Grad setting and updating

Graph Capture Progress

So, where are we now? We are not done yet, but the list of remaining gaps for supporting FSDP looks relatively manageable and sane. All the hacks are gone, and almost all the changes (aside from a few small op-ification steps in FSDP itself) are within dynamo.

In no particular order, the following are already done in the PR above but still need to land:

  • Quite a few bug fixes
  • Better support for submodule, param, and buffer iteration for FSDP-wrapped modules
  • FSDP - op-ification for certain storage comparison operations

Graph Lowering Progress

We have not yet started on this for this workstream.

Next Steps

  1. Break the PR listed above apart into sane, small pieces that can be reviewed
  2. Land a battery of compile + FSDP tests to hold the line on our graph capture work
  3. Expand the scope of compile + FSDP to more complex models beyond the current set of toy models used during development
  4. Lower the captured graph through inductor

Thanks to all the contributors to this effort so far: @ezyang, @bdhirsh, @jansel, @awgu, @Chillee, @albanD


It’s great to see a summarization blog like this!

In addition, are there any similar blogs on torch.compile support for DataParallel and DistributedDataParallel?