TorchDynamo Update 11: Making FSDP and Dynamo Work Together


TorchDynamo Post Series

This is another post in the series about TorchDynamo. Here are the previous dynamo updates:

Refresher: What FSDP Does

Fully Sharded Data Parallel (FSDP) lets you train huge models that won’t fit entirely inside one GPU, without much more than applying a wrapper and some hints to your PyTorch model.

It’s a form of Data Parallel training, meaning it requires multiple GPUs, but unlike plain Distributed Data Parallel (DDP), which keeps a full copy of the model on each GPU, FSDP shards the model across all the GPUs. Each GPU only needs to reserve enough working memory to gather the full states (parameters, etc.) of one layer at a time, or more precisely one FSDP Unit, which can group multiple ‘layers’.
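The memory savings can be estimated with a back-of-the-envelope sketch. It counts parameter storage only (no gradients, optimizer state, activations, or communication buffers), assumes parameters split evenly across GPUs, and uses hypothetical function names and model sizes.

```python
def ddp_param_bytes(total_params: int, bytes_per_param: int = 4) -> int:
    # DDP: every GPU keeps a full replica of the parameters.
    return total_params * bytes_per_param


def fsdp_param_bytes(total_params: int, largest_unit_params: int,
                     world_size: int, bytes_per_param: int = 4) -> int:
    # FSDP: each GPU persistently keeps only its 1/world_size shard...
    shard = total_params * bytes_per_param // world_size
    # ...plus working memory to gather the full parameters of one unit at a time.
    working = largest_unit_params * bytes_per_param
    return shard + working


# Hypothetical example: 10B fp32 parameters, largest FSDP unit of 250M, 8 GPUs.
print(ddp_param_bytes(10_000_000_000) / 1e9)                     # 40.0 (GB)
print(fsdp_param_bytes(10_000_000_000, 250_000_000, 8) / 1e9)    # 6.0 (GB)
```

The working-memory term is why FSDP unit size matters: larger units mean fewer collectives but a higher per-GPU peak.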

But for this to work well, not only does FSDP need to modify your model to reduce the sizes of parameter and optimizer state tensors, it has to carefully orchestrate the timing of collective communication operations (all_gather, reduce_scatter) to rebuild the full local copy of a parameter or state just before it is needed, and then free it right after.
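To make the two collectives concrete, here is a pure-Python toy of their semantics on per-rank lists, with no real GPUs or process groups involved: all_gather rebuilds the full parameter from every rank’s shard, and reduce_scatter sums each rank’s full-size local gradient and hands each rank back only its own slice of the sum. The helper names are hypothetical stand-ins for the real torch.distributed ops.

```python
from typing import List


def shard(full: List[float], world_size: int) -> List[List[float]]:
    # Split a flat parameter into equal contiguous shards, one per rank.
    assert len(full) % world_size == 0, "toy assumes an even split"
    n = len(full) // world_size
    return [full[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather(shards: List[List[float]]) -> List[List[float]]:
    # Every rank receives the concatenation of all ranks' shards.
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def reduce_scatter(local_grads: List[List[float]]) -> List[List[float]]:
    # Sum the full-size local gradients element-wise across ranks,
    # then hand rank r only shard r of the sum.
    summed = [sum(vals) for vals in zip(*local_grads)]
    return shard(summed, len(local_grads))


# Two hypothetical ranks sharing one 4-element parameter.
shards = shard([1.0, 2.0, 3.0, 4.0], 2)            # [[1.0, 2.0], [3.0, 4.0]]
print(all_gather(shards))                          # both ranks: [1.0, 2.0, 3.0, 4.0]
print(reduce_scatter([[1.0, 2.0, 3.0, 4.0],
                      [10.0, 20.0, 30.0, 40.0]]))  # [[11.0, 22.0], [33.0, 44.0]]
```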

FSDP makes heavy use of two features of PyTorch: nn.Module structure, and hooks together with overriding nn.Module forward(). The former lets FSDP reason about ownership of parameters and memory, and make assumptions about where to find the parameters it needs to modify (shard). The latter gives FSDP a mechanism to run special code just before and just after key parts of the model, such as the forward function.
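The hook pattern can be sketched with a toy stand-in for nn.Module whose __call__ runs pre-hooks, then forward(), then post-hooks. The method names mirror register_forward_pre_hook and register_forward_hook, but the signatures are simplified (the real pre-hooks receive (module, args) and forward hooks (module, args, output)), and the unshard/reshard helpers are hypothetical stand-ins for what FSDP does around forward().

```python
from typing import Callable, List


class ToyModule:
    """Minimal stand-in for nn.Module: __call__ runs hooks around forward()."""

    def __init__(self) -> None:
        self._pre_hooks: List[Callable] = []
        self._post_hooks: List[Callable] = []

    def register_forward_pre_hook(self, hook: Callable) -> None:
        self._pre_hooks.append(hook)

    def register_forward_hook(self, hook: Callable) -> None:
        self._post_hooks.append(hook)

    def __call__(self, x):
        for hook in self._pre_hooks:
            hook(self)              # e.g. all_gather the full parameters
        out = self.forward(x)
        for hook in self._post_hooks:
            hook(self, out)         # e.g. free the gathered parameters
        return out


class ToyLinear(ToyModule):
    def __init__(self, shards: List[List[float]]) -> None:
        super().__init__()
        self.shards = shards        # stands in for the shards spread across ranks
        self.weight = None          # the full weight only exists around forward()

    def forward(self, x: float) -> List[float]:
        return [w * x for w in self.weight]


events: List[str] = []


def unshard(mod: ToyLinear) -> None:
    # Stand-in for all_gather: rebuild the full weight just before forward().
    mod.weight = [v for s in mod.shards for v in s]
    events.append("unshard")


def reshard(mod: ToyLinear, out) -> None:
    # Free the full weight right after forward(), keeping only the shards.
    mod.weight = None
    events.append("reshard")


layer = ToyLinear([[1.0, 2.0], [3.0, 4.0]])
layer.register_forward_pre_hook(unshard)
layer.register_forward_hook(reshard)
print(layer(2.0), layer.weight, events)   # [2.0, 4.0, 6.0, 8.0] None ['unshard', 'reshard']
```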

Enter PyTorch 2.0 and TorchDynamo

TorchDynamo is the graph capture frontend that powers PyTorch 2.0. It works by understanding just enough about Python to capture straight-line sections of PyTorch operations and lower them to a compiler backend, while seamlessly falling back to running the parts of the code it doesn’t understand natively in CPython. This tradeoff lets it capture most of the value (regions suitable for optimization) while staying flexible enough to fit most programs without modification. But it also poses challenges for some of the existing infrastructure powering FSDP.

TorchDynamo lets a backend compile optimized replacements for specific parts of a Python program, while the other parts execute natively in Python. It can analyze the original Python program opcode by opcode, and reason safely about the conditions (guards) it must verify at runtime to validate the assumptions made by the optimized code.

However, performing this level of analysis everywhere, and evaluating all the runtime checks it would imply, can be too expensive. To balance this out, Dynamo ‘specializes’ its behavior for some common PyTorch library components, such as ‘nn.Module’. For the nn.Modules it encounters, it assumes they will not be mutated (modified) in certain ways, since this is uncommon. Rather than adding runtime checks on every property and parameter of the module and its children (recursively), it bakes the module into its compiled code and adds a hook to the original module’s setattr that invalidates the compiled code if the module is modified.
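The specialization idea can be illustrated with a toy: a "compiler" bakes a module attribute into its compiled code as a constant, so no per-call guard is needed, and instead registers a listener on the module’s setattr that drops the cached code whenever the module is mutated. Everything here (ToyModule, SpecializingCompiler, the listener list) is hypothetical and only mimics the shape of what Dynamo does, not its actual mechanism.

```python
from typing import Callable, Dict, Set


class ToyModule:
    """Stand-in module that notifies listeners whenever an attribute is set."""

    def __init__(self, scale: float) -> None:
        object.__setattr__(self, "_listeners", [])
        self.scale = scale

    def __setattr__(self, name: str, value) -> None:
        for listener in self._listeners:
            listener(self, name)
        object.__setattr__(self, name, value)


class SpecializingCompiler:
    """Toy: bake module attributes into compiled code, invalidate on mutation."""

    def __init__(self) -> None:
        self._cache: Dict[int, Callable[[float], float]] = {}
        self._watched: Set[int] = set()
        self.compiles = 0

    def _compile(self, mod: ToyModule) -> Callable[[float], float]:
        self.compiles += 1
        scale = mod.scale                 # baked in as a constant: no per-call guard
        return lambda x: x * scale

    def _invalidate(self, mod: ToyModule, name: str) -> None:
        self._cache.pop(id(mod), None)

    def __call__(self, mod: ToyModule, x: float) -> float:
        key = id(mod)
        if key not in self._cache:
            self._cache[key] = self._compile(mod)
            if key not in self._watched:  # hook setattr once per module
                mod._listeners.append(self._invalidate)
                self._watched.add(key)
        return self._cache[key](x)


m = ToyModule(2.0)
run = SpecializingCompiler()
print(run(m, 3.0), run(m, 4.0), run.compiles)   # 6.0 8.0 1
m.scale = 10.0                                  # mutation invalidates the cache
print(run(m, 3.0), run.compiles)                # 30.0 2
```

The tradeoff is that a module which is legitimately modified on every call, as FSDP-wrapped modules are, would keep invalidating and recompiling under this scheme.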

What If We Just Put Them Together?

Let’s ignore all of the above and just try to run an FSDP model with TorchDynamo.

simple wrap, without sharding - kind of like basic DDP

   FSDP(model)

recursive wrap, with sharding - extra comms to gather sharded layers

   FSDP(model, auto_wrap_policy=<my_policy>)

TorchDynamo will treat FSDP as any other nn.Module and attempt to (1) deepcopy it with FakeTensor parameters, (2) use FakeTensor to evaluate the output shape/dtypes of its forward(), (3) use the copied module as a safety catch in case it mutates itself during tracing, (4) gather up the ‘Parameter’ tensors in the module so they can be specially handled by AotAutograd.

The first problem (P1), which is probably trivial, is that the FSDP wrapper can’t be deepcopied, because it holds a ‘ProcessGroup’ object that is not pickleable. ProcessGroup could be made pickleable, but ultimately P1 won’t be the most important issue, and it can be worked around.
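P1 is easy to reproduce in plain Python: copy.deepcopy recursively copies an object’s attributes and fails as soon as it hits a handle that can’t be pickled. In this sketch a threading.Lock stands in for the ProcessGroup’s underlying handle, and FakeProcessGroup, ToyFSDPWrapper, and SharedProcessGroup are hypothetical names. The __deepcopy__ override shows one possible workaround (sharing the handle instead of copying it), not necessarily the one PyTorch would choose.

```python
import copy
import threading


class FakeProcessGroup:
    """Hypothetical stand-in for a ProcessGroup holding a non-copyable handle."""

    def __init__(self) -> None:
        self._handle = threading.Lock()   # locks cannot be pickled or deep-copied


class ToyFSDPWrapper:
    def __init__(self, inner, process_group) -> None:
        self.inner = inner
        self.process_group = process_group


class SharedProcessGroup(FakeProcessGroup):
    def __deepcopy__(self, memo):
        # One possible workaround: copies share the same communicator handle.
        return self


try:
    copy.deepcopy(ToyFSDPWrapper({"w": [1.0]}, FakeProcessGroup()))
except TypeError as e:
    print("deepcopy failed:", e)

pg = SharedProcessGroup()
clone = copy.deepcopy(ToyFSDPWrapper({"w": [1.0]}, pg))
print(clone.process_group is pg, clone.inner == {"w": [1.0]})   # True True
```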

The second problem (P2) is that TorchDynamo currently does not support nn.Module hooks in its specialized NNModuleVariable implementation (see gh issue #91665 for ongoing work to support hooks). Recall that FSDP relies on hooks to orchestrate the communication ops that materialize parameters just before calling .forward(). Right now, TorchDynamo skips these hooks, leaving FSDP completely broken. And even if we did support hooks (which we plan to add to TorchDynamo as a matter of course), I suspect we’d uncover additional problems, because the code inside FSDP’s hooks mutates the nn.Module in ways that are difficult for Dynamo to handle.

The third problem (P3) is that the parameters TorchDynamo extracts in step (4) and attaches to its FX graph for AotAutograd are not the same parameters FSDP actually uses at runtime: FSDP replaces them with views into a flat buffer to make communication ops more efficient.
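A stdlib-only sketch of the flat-buffer idea, using array and memoryview in place of tensors: every parameter is copied into one contiguous buffer, each original name is re-exposed as a view into it, and a write to the buffer (say, from an all_gather) is visible through the views but not through references captured before wrapping. The function name flatten_into_views is hypothetical.

```python
from array import array
from typing import Dict, Tuple


def flatten_into_views(params: Dict[str, array]) -> Tuple[array, Dict[str, memoryview]]:
    # Copy every parameter into one contiguous buffer (one big collective)...
    flat = array("f")
    for p in params.values():
        flat.extend(p)
    # ...then expose each original name as a view into that buffer.
    mv = memoryview(flat)
    views, offset = {}, 0
    for name, p in params.items():
        views[name] = mv[offset:offset + len(p)]
        offset += len(p)
    return flat, views


params = {"weight": array("f", [1.0, 2.0, 3.0, 4.0]), "bias": array("f", [5.0, 6.0])}
captured = dict(params)          # what a tracer grabbed before wrapping
flat, views = flatten_into_views(params)
flat[0] = 100.0                  # e.g. an all_gather writing into the flat buffer
print(views["weight"][0], captured["weight"][0])   # 100.0 1.0
```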

The fourth problem (P4) is that collective operations can’t be traced into the graph by TorchDynamo or handled by AotAutograd, due to (a) missing support for ProcessGroup in TorchDynamo, (b) the lack of functional versions of the collectives, which means Functionalization can’t run on them, and (c) the collectives returning a ‘Work’ object instead of a Tensor, which violates FX graph assumptions and isn’t handled by AotAutograd or the backends. Unlike in DDP, these operations happen during both forward and backward, so graph breaks are required for the model to work at all, not just for performance. (See RFC #93173, PT2-Friendly Traceable, Functional Collectives, for more information.)
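Point (c) is easier to see side by side. In this pure-Python toy, the handle-returning style (modeled loosely on torch.distributed.all_reduce with async_op=True) mutates an input buffer in place and returns a Work-like object, so the op has no usable output value until wait() is called. The functional style leaves its inputs untouched and returns a new value, which is what an FX graph and Functionalization expect. All names here are hypothetical.

```python
from typing import List


class Work:
    """Toy async handle: the caller's buffer only holds the result after wait()."""

    def __init__(self, buffer: List[float], result: List[float]) -> None:
        self._buffer = buffer
        self._result = result

    def wait(self) -> None:
        self._buffer[:] = self._result   # in-place mutation of the input buffer


def all_reduce_with_handle(per_rank: List[List[float]], rank: int) -> Work:
    # Handle style: reduces into an input buffer in place and returns a handle.
    summed = [sum(vals) for vals in zip(*per_rank)]
    return Work(per_rank[rank], summed)


def all_reduce_functional(per_rank: List[List[float]]) -> List[float]:
    # Functional style: inputs are left untouched; the result is a new value
    # that later ops can consume directly.
    return [sum(vals) for vals in zip(*per_rank)]


bufs = [[1.0, 2.0], [3.0, 4.0]]
work = all_reduce_with_handle(bufs, rank=0)
print(bufs[0])        # [1.0, 2.0]: nothing usable until wait()
work.wait()
print(bufs[0])        # [4.0, 6.0]
print(all_reduce_functional([[1.0, 2.0], [3.0, 4.0]]))   # [4.0, 6.0]
```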

Making it Actually Work

The current solution for FSDP + TorchDynamo breaks down into three components:

  1. Only compile the original nn.Module wrapped inside FSDP, graph-breaking between layers and executing FSDP wrapper code eagerly (#87420)

  2. Add special handling for the parameter-views of modules wrapped by FSDP, so they get properly fed to AotAutograd (#88781, #89523)

  3. Compile the original nn.Module as Unspecialized (#89330)

Step (1) goes a long way in solving the original list of problems. P4 disappears as a problem if we already assume graph-breaks, since the collectives can just be executed eagerly in between compiled graphs. P1 and P2 are also fixed, since eager execution of FSDP library code calls the hooks and TorchDynamo won’t attempt to deepcopy the FSDP wrapper.
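The control flow of step (1) can be sketched as a toy in which each FSDP unit’s wrapper code, including its collectives, runs eagerly in Python, and only the original inner module is a "compiled" region, with a graph break between units. ToyFSDPUnit, toy_compiled_graph, and run_model are hypothetical, and the event log only records the order in which things would happen.

```python
from typing import Callable, List

events: List[str] = []


def toy_compiled_graph(name: str) -> Callable[[float], float]:
    # Stand-in for the compiled graph of one original (inner) module.
    def run(x: float) -> float:
        events.append(f"graph:{name}")
        return x + 1.0
    return run


class ToyFSDPUnit:
    """Toy wrapper whose own code, including collectives, runs eagerly."""

    def __init__(self, name: str, inner: Callable[[float], float]) -> None:
        self.name = name
        self.inner = inner

    def __call__(self, x: float) -> float:
        events.append(f"all_gather:{self.name}")   # eager: never traced
        out = self.inner(x)                        # compiled region
        events.append(f"reshard:{self.name}")      # eager again
        return out


def run_model(units: List[ToyFSDPUnit], x: float) -> float:
    for unit in units:
        x = unit(x)                                # graph break between units
    return x


units = [ToyFSDPUnit(f"layer{i}", toy_compiled_graph(f"layer{i}")) for i in range(2)]
print(run_model(units, 0.0), events)
```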

Unfortunately, P3 remains an issue because TorchDynamo is still compiling the original nn.Module that FSDP wraps, making assumptions about its parameters which FSDP will then violate when its hooks run.

(2) gets us part of the way there, at least making AotAutograd aware of the parameters it needs to handle. The issue is that FSDP actually removes the parameter tensors from the original wrapped module and replaces them with new module attributes of the same names. These are not actual parameters but views into a flat buffer, so they are hidden from TorchDynamo and AotAutograd. A fairly straightforward hack is now in place: FSDP adds hints that TorchDynamo looks for, and TorchDynamo then lumps these ‘non-parameter-view’ tensors in with the other parameters. Note: this step also requires FSDP’s “use_orig_params=True” flag to be set.
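The shape of that hint-based hack can be sketched in plain Python: the wrapper tags the plain-tensor views it installs, and the capture step collects tagged attributes alongside real parameters. Parameter and FlatView are toy list subclasses, and the marker name _toy_fsdp_param_view is hypothetical; the actual hint FSDP sets and TorchDynamo reads is not reproduced here.

```python
class Parameter(list):
    """Toy stand-in for nn.Parameter."""


class FlatView(list):
    """Toy stand-in for a plain tensor view into FSDP's flat buffer."""


HINT = "_toy_fsdp_param_view"   # hypothetical marker, not the real attribute name


def mark_param_view(view: FlatView) -> FlatView:
    # Wrapper side: tag the plain-tensor views it installs on the module.
    setattr(view, HINT, True)
    return view


def collect_graph_params(module) -> dict:
    # Capture side: treat real Parameters and hinted views as parameters.
    return {
        name: value
        for name, value in vars(module).items()
        if isinstance(value, Parameter) or getattr(value, HINT, False)
    }


class Inner:
    pass


inner = Inner()
inner.bias = Parameter([0.0])
inner.weight = mark_param_view(FlatView([1.0, 2.0]))
inner.scratch = FlatView([9.0])              # untagged buffer: ignored
print(sorted(collect_graph_params(inner)))   # ['bias', 'weight']
```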

So far we have succeeded in making FSDP run correctly, and the only anticipated performance hit is from graph breaks, which, as demonstrated in the DDPOptimizer case, should not be a big deal for most models. Unfortunately, a major performance issue popped up during benchmarking: collective operations during backward were not overlapping with compute. Even though the graph breaks were working and there were several compiled backward segments, the communication operations that should have been interleaved with compute were instead running serially at the end.

This truly puzzled me, and to debug it I had to dive into the guts of Autograd and understand how it schedules operations based on the reverse of the order in which they were created during forward. See the PR comments for an illustrated explanation.

The problem was one that FSDP had already worked around in its eager implementation. Naively, it would seem fine for the ‘parameter-views’ that the FSDP wrapper sticks onto the wrapped module to be created just once, at wrapping time, and then reused. But FSDP relies on backward (autograd) hooks to kick off reduce_scatter operations on the underlying flat buffer once all of its linked parameter-views have received local gradients. For Autograd to fire these hooks as soon as a layer finishes producing local grads, the parameter-views must be recreated on each forward rather than once at wrapping time.

TorchDynamo’s ‘specialized’ NNModuleVariable handling makes the same naive assumption: it’s fine to grab references to a module’s parameters once during compilation, since parameters are usually long-lived tensors. Indeed, this works in a functional sense, since Dynamo holds ‘views’ into the flat buffer, and those views remain valid later. But Autograd considers all these views as ‘created at time zero’, and thus never prioritizes their gradient hooks.
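The scheduling effect can be reproduced with a heavily simplified toy model. Each node gets a creation sequence number, and backward repeatedly runs the ready node with the highest sequence number (most recently created first), matching the post’s description of scheduling in reverse creation order. Reaching a viewN node stands in for the moment that view’s gradient hook would launch reduce_scatter. Graph, backward_order, and build are hypothetical names, and the real autograd engine has more scheduling rules than this single queue.

```python
import heapq
import itertools
from typing import Dict, List, Optional, Tuple


class Graph:
    """Toy autograd graph: each node records a creation sequence number."""

    def __init__(self) -> None:
        self._seq = itertools.count()
        self.nodes: Dict[str, Tuple[int, List[str]]] = {}

    def node(self, name: str, inputs: Optional[List[str]] = None) -> str:
        # `inputs` are the nodes this node sends gradients to in backward.
        self.nodes[name] = (next(self._seq), list(inputs or []))
        return name


def backward_order(graph: Graph, root: str) -> List[str]:
    # A node becomes ready once every node that feeds it a gradient has run.
    pending = {name: 0 for name in graph.nodes}
    for _, inputs in graph.nodes.values():
        for inp in inputs:
            pending[inp] += 1
    # Ready queue: highest sequence number (most recently created) first.
    ready = [(-graph.nodes[root][0], root)]
    order: List[str] = []
    while ready:
        _, name = heapq.heappop(ready)
        order.append(name)
        for inp in graph.nodes[name][1]:
            pending[inp] -= 1
            if pending[inp] == 0:
                heapq.heappush(ready, (-graph.nodes[inp][0], inp))
    return order


def build(recreate_views: bool, layers: int = 3) -> Tuple[Graph, str]:
    g = Graph()
    # Views created once "at wrapping time" get the smallest sequence numbers.
    early = [] if recreate_views else [g.node(f"view{i}") for i in range(layers)]
    prev: Optional[str] = None
    for i in range(layers):
        view = g.node(f"view{i}") if recreate_views else early[i]
        inputs = [view] if prev is None else [view, prev]
        prev = g.node(f"compute{i}", inputs)
    return g, prev


for recreate in (True, False):
    g, root = build(recreate)
    print(recreate, backward_order(g, root))
```

With views recreated per forward, each viewN is reached right after computeN, so reduce_scatter could overlap with the remaining backward compute. With views created once, every view is reached only after all compute has finished, which matches the serialized communication observed in the benchmark.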

At this point, I threw up my hands and thought: what if we just don’t make any assumptions at the Dynamo layer? There is already an ‘UnspecializedNNModuleVariable’ that Dynamo can fall back to when the specialized version runs into trouble, so let’s just treat all FSDP-wrapped modules as ‘Unspecialized’. With that, TorchDynamo correctly feeds the newly created parameter-views as inputs to each call, solving our problem.
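The difference between the two treatments boils down to when the attribute is read. In this toy, the "specialized" path grabs the module’s weight once, at compile time, while the "unspecialized" path looks it up on every call, so only the latter sees the fresh parameter-view that an FSDP-style pre-forward step installs. The class and function names are hypothetical.

```python
class ToyWrappedModule:
    """Toy module whose wrapper installs a fresh parameter-view every forward."""

    def __init__(self) -> None:
        self.generation = 0
        self.weight = ("view", 0)

    def pre_forward(self) -> None:
        # FSDP-style: recreate the parameter-view before each forward.
        self.generation += 1
        self.weight = ("view", self.generation)


def specialized(module: ToyWrappedModule):
    weight = module.weight          # reference grabbed once, at compile time
    return lambda: weight


def unspecialized(module: ToyWrappedModule):
    return lambda: module.weight    # looked up again (as an input) on every call


m = ToyWrappedModule()
spec, unspec = specialized(m), unspecialized(m)
m.pre_forward()
print(spec(), unspec())   # ('view', 0) ('view', 1)
```

The cost of reading inputs afresh is the extra runtime guarding discussed under the limitations below.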

Current Limitations

Graph breaks

The main problems (or, potential problems, as they generally turn out to be OK) with graph breaks are summarized as follows:

  1. performance loss due to missed optimization opportunities
  2. rendering a model non-exportable
  3. exposing issues with functionalization + AotAutograd by compiling more graphs with more inputs, any of which could be mutated
  4. similarly, increasing memory usage under CUDA Graphs due to having more inputs for more graphs.

(1) is not a big deal in practice for typical models, as Inductor optimizations are generally local in scope and FSDP breaks models into relatively large chunks anyway. (2) is a non-issue for many users, but it depends on the use case. (3) should be a short-term, fixable problem, as these are literally bugs in other parts of the stack which should be fixed as they are discovered. (4) hasn’t come up as a problem so far, but could affect models that are right on the edge of memory capacity.

In summary, graph breaks probably won’t be a blocker for most users, but where they are, removing them is not trivial, since it would require new solutions for all of P1, P2, and P4. Note that P4 (traceable collectives) support is under way; once it lands, it will be possible not only to try to fix the existing FSDP but also to support FSDP natively inside the compiler.

UnspecializedNNModule Compilation

The downside to TorchDynamo treating wrapped modules as Unspecialized is that it introduces more guards that must be evaluated at runtime. The performance impact of these extra guards has not been characterized recently, and may be a minor or major issue depending on the model. We plan to benchmark this overhead on important models.


Thanks for reading! Hopefully this post explained the challenges of, and solutions for, composing FSDP with the PT2 stack. Please let us know how this is working for you in practice, and file any issues on the PyTorch GitHub repo.