Rethinking PyTorch Fully Sharded Data Parallel (FSDP) from First Principles

Given some interest, I am sharing a note (first written internally) on the PyTorch Fully Sharded Data Parallel (FSDP) design. This covers much but not all of it (e.g. it excludes autograd and CUDA caching allocator interaction). I can share more details if there is further interest.

TL;DR

  • We rethought the PyTorch FSDP design from first principles to uncover a new one that takes a first step toward improving composability and flexibility.
  • This includes an experimental fully_shard API that is part of a broader eager distributed composable API effort. This is a work in progress and not ready for general use yet.
  • This includes the ability for the existing FullyShardedDataParallel to expose the original parameters (not FlatParameters) via use_orig_params=True, which enables flexible support for multiple parameter groups.

Introduction

In data parallelism, parameters are replicated across ranks, and each rank operates on a local batch, computing local gradients. Ranks all-reduce their local gradients to compute the global gradient and then run a replicated optimizer step.

In sharded data parallelism, parameters are sharded across ranks, and ranks all-gather them as needed for unsharded forward/backward computation and free them when they are no longer needed. Since each rank only manages a shard of the parameters, each rank only needs the corresponding gradient shard to run the optimizer step. Hence, ranks reduce-scatter their local unsharded gradients to compute the sharded global gradient and then run a sharded optimizer step.
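To make the difference between the two collectives concrete, here is a single-process simulation with plain tensors rather than torch.distributed. The helpers all_reduce, reduce_scatter, and all_gather are hypothetical stand-ins that mimic what the collectives compute (with a sum reduction), not how they communicate.

```python
import torch

# Single-process simulation: local_grads[r] is rank r's local unsharded gradient.
world_size, numel = 4, 8  # numel assumed divisible by world_size
torch.manual_seed(0)
local_grads = [torch.randn(numel) for _ in range(world_size)]

def all_reduce(tensors):
    """Data parallelism: every rank ends up with the full summed gradient."""
    total = torch.stack(tensors).sum(dim=0)
    return [total.clone() for _ in tensors]

def reduce_scatter(tensors):
    """Sharded data parallelism: rank r ends up with only chunk r of the sum."""
    total = torch.stack(tensors).sum(dim=0)
    return [chunk.clone() for chunk in total.chunk(len(tensors))]

def all_gather(shards):
    """Every rank reconstructs the full tensor from the per-rank shards."""
    full = torch.cat(shards)
    return [full.clone() for _ in shards]

replicated = all_reduce(local_grads)
sharded = reduce_scatter(local_grads)
for r in range(world_size):
    # Each rank's reduce-scatter output is its slice of the all-reduce output.
    assert torch.equal(sharded[r], replicated[r].chunk(world_size)[r])
# All-gathering the gradient shards recovers the all-reduced gradient.
assert torch.equal(all_gather(sharded)[0], replicated[0])
```

In other words, reduce-scatter is an all-reduce whose output each rank only keeps a slice of, which is why a sharded optimizer step needs nothing more.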

In sharded data parallelism, only the memory of the sharded parameters, sharded gradients, and sharded optimizer states plus some memory for the largest subset of simultaneously unsharded parameters contribute to a rank’s peak memory. These memory savings may be critical for training large models or scaling batch size.
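As a back-of-the-envelope sketch of these savings, the hypothetical helpers below count only parameters, gradients, and optimizer states (activations and communication buffers are ignored). They assume fp32 values and an Adam-like optimizer with two states per parameter, and the model size and unsharded-subset size in the example are made-up numbers.

```python
def replicated_bytes(num_params, bytes_per_value=4, optim_states=2):
    # Param + grad + optimizer states, all fully replicated on every rank.
    return num_params * bytes_per_value * (2 + optim_states)

def sharded_bytes(num_params, world_size, max_unsharded_params,
                  bytes_per_value=4, optim_states=2):
    # 1/world_size of the param/grad/optimizer-state memory, plus the
    # largest set of parameters that are unsharded at the same time.
    steady = num_params * bytes_per_value * (2 + optim_states) / world_size
    transient = max_unsharded_params * bytes_per_value
    return steady + transient

GB = 1e9
# Assumed example: 7B params, 8 ranks, at most 0.5B params unsharded at once.
print(replicated_bytes(7e9) / GB)         # 112.0
print(sharded_bytes(7e9, 8, 0.5e9) / GB)  # 16.0
```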

In fully sharded data parallelism, the sharding factor is specifically set to the world size. I.e., the parameters are sharded across all ranks without replication. This algorithm is commonly called ZeRO-3, and PyTorch’s Fully Sharded Data Parallel (FSDP) is one implementation, where a central challenge is working within the PyTorch framework. (The sharding factor need not be the world size; setting it to be the number of intra-node devices gives the alternative Hybrid Sharded Data Parallel (HSDP).)
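To illustrate how the sharding factor determines communication groups, the sketch below uses a hypothetical helper and assumes ranks are numbered node-major, so consecutive ranks share a node. It lists which ranks shard parameters among themselves (all-gather/reduce-scatter) and which ranks hold replicas of the same shard (gradient all-reduce).

```python
def shard_and_replicate_groups(world_size, sharding_factor):
    """Return (shard_groups, replicate_groups) as lists of rank lists."""
    assert world_size % sharding_factor == 0
    num_replicas = world_size // sharding_factor
    # Ranks [i*F, (i+1)*F) split the parameters among themselves.
    shard_groups = [list(range(i * sharding_factor, (i + 1) * sharding_factor))
                    for i in range(num_replicas)]
    # Ranks j, j+F, j+2F, ... hold the same shard.
    replicate_groups = [list(range(j, world_size, sharding_factor))
                        for j in range(sharding_factor)]
    return shard_groups, replicate_groups

# FSDP (ZeRO-3): sharding factor == world size, so no replication.
print(shard_and_replicate_groups(8, 8))
# ([[0, 1, 2, 3, 4, 5, 6, 7]], [[0], [1], [2], [3], [4], [5], [6], [7]])

# HSDP on 2 nodes x 4 GPUs: shard within a node, replicate across nodes.
print(shard_and_replicate_groups(8, 4))
# ([[0, 1, 2, 3], [4, 5, 6, 7]], [[0, 4], [1, 5], [2, 6], [3, 7]])
```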

PyTorch FSDP was upstreamed from Fairscale FSDP. This note provides one possible perspective on this evolving FSDP design from first principles and motivates a modified design striving for improved composability. Throughout, we treat performance as a first-class constraint since a non-performant design is not usable.

Constraints: Communication

FSDP targets NVIDIA GPUs and uses NCCL for collective communication, and it requires all-gather to unshard parameters and reduce-scatter to reduce gradients. PyTorch’s all_gather_into_tensor requires even input sizes across ranks but is more performant than all_gather, and the same holds for reduce_scatter_tensor vs. reduce_scatter. Moreover, for a fixed communication volume, batching data into fewer, larger collectives is more performant. This yields the following two constraints:

Constraint 1: FSDP should communicate even sizes across ranks to use all_gather_into_tensor/reduce_scatter_tensor.

Constraint 2: FSDP should batch parameters for all-gather and gradients for reduce-scatter.

FlatParameter

Constraints 1 and 2 motivate the FlatParameter abstraction: a 1D tensor that is the concatenation of n flattened original parameters (with optional right-padding to ensure evenness). It serves as FSDP’s atomic unit of communication, and it owns the storage of the constituent original parameters.
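A minimal single-process sketch of this idea, using plain tensors rather than FSDP’s actual FlatParameter class: the hypothetical flatten_params helper copies each parameter into a right-padded 1D buffer and returns views with the original shapes. The buffer therefore owns the storage, and every rank’s shard has the same size.

```python
import math
import torch

def flatten_params(params, world_size):
    """Concatenate flattened params into one 1D tensor, right-padded so its
    numel is divisible by world_size, and return views with the original shapes."""
    total = sum(p.numel() for p in params)
    padded = math.ceil(total / world_size) * world_size
    flat = torch.zeros(padded, dtype=params[0].dtype)
    views, offset = [], 0
    for p in params:
        n = p.numel()
        flat[offset:offset + n].copy_(p.detach().reshape(-1))
        views.append(flat[offset:offset + n].view(p.shape))  # shares flat's storage
        offset += n
    return flat, views

weight, bias = torch.randn(3, 5), torch.randn(3)
flat, (w_view, b_view) = flatten_params([weight, bias], world_size=4)
print(flat.numel())                 # 20: 18 real elements + 2 padding
shards = flat.chunk(4)              # equal-sized shards, one per rank
print([s.numel() for s in shards])  # [5, 5, 5, 5]
flat.zero_()                        # writes to the flat tensor show through the views
print(bool((w_view == 0).all()))    # True
```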

When computation involves an original parameter, the owning FlatParameter must be unsharded and can only be resharded outside the computation. Thus, a performant FlatParameter construction groups original parameters involved in computation around the same time, and ideally each FlatParameter is only unsharded/resharded once for a given forward or backward to minimize the number of all-gathers.

Key Question: For a given model, how should FSDP construct FlatParameters from the model’s parameters?

FlatParameter Construction

FullyShardedDataParallel is a module wrapper, applying transformations to the wrapped module at construction time and at runtime (i.e. forward/backward/optimizer step). As an eager API, it only has access to the static module structure at construction time. FSDP wrapping leverages the module structure to inform the FlatParameter construction, in the hope that model authors group parameters with the desired locality into the same module or module subtree.

Rule 1: If the user wraps fsdp_module = FullyShardedDataParallel(module), then every parameter in module not already flattened is flattened into a single new FlatParameter and assigned to fsdp_module.

Rule 1 is just one way to leverage module structure, which happens to be sufficiently performant while being simple to reason about. From a model printout, a user/developer can quickly infer the parameter assignment, and this simplicity can be crucial to debugging and adoption. An example module tree is shown above, where red indicates a FullyShardedDataParallel-wrapped module and yellow indicates a non-directly-wrapped module. This wrapping constructs four FlatParameters (shown by the dotted lines), which are assigned to modules 0, 1, 3, and 7, respectively.
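The assignment that Rule 1 induces can be computed from the module tree alone. Below is a hypothetical helper (not FSDP code) that treats a given set of module names as wrapped, processes them bottom-up, and lets each wrapped module claim the parameters in its subtree that an inner wrapped module has not already claimed. The small model is illustrative, not the tree from the figure.

```python
import torch.nn as nn

def assign_flat_params(root, wrapped_names):
    """Rule 1 sketch: return {wrapped module name: [parameter FQNs it flattens]}.
    Assumes the root ("") is among the wrapped modules."""
    claimed, assignment = set(), {}
    # named_modules() is pre-order, so reversing it visits descendants first.
    for name, module in reversed(list(root.named_modules())):
        if name not in wrapped_names:
            continue
        prefix = f"{name}." if name else ""
        fqns = [prefix + pname for pname, _ in module.named_parameters()
                if prefix + pname not in claimed]
        claimed.update(fqns)
        assignment[name or "<root>"] = fqns
    return assignment

model = nn.Sequential(
    nn.Linear(4, 4),                                  # "0"
    nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)),  # "1" with children "1.0", "1.1"
    nn.Linear(4, 2),                                  # "2"
)
print(assign_flat_params(model, {"", "1", "1.1"}))
# {'1.1': ['1.1.weight', '1.1.bias'], '1': ['1.0.weight', '1.0.bias'],
#  '<root>': ['0.weight', '0.bias', '2.weight', '2.bias']}
```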

To be clear, if a user only applies FullyShardedDataParallel to the root module, then there is only a single FlatParameter consisting of all original parameters, meaning there is no communication/computation overlap and the entire unsharded parameter size contributes to peak memory. Thus, we need multiple FlatParameters, which, by Rule 1, requires recursive/nested FSDP wrapping.

Manual & Auto Wrapping

Manual wrapping refers to applying submodule = FullyShardedDataParallel(submodule) on target submodules, where this application proceeds bottom-up. For large/complex models, manual wrapping may not be tractable, so we introduced auto wrapping to use heuristics to perform the nested wrapping automatically. Heuristics include a parameter size threshold and target nn.Module classes to wrap.
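As a sketch of what a size-threshold heuristic might look like (PyTorch ships its own policies, e.g. in torch.distributed.fsdp.wrap, and this is not their implementation), the hypothetical size_policy_targets helper walks the module tree bottom-up and marks a module for wrapping once the parameters not covered by already-marked descendants reach min_num_params. A class-based heuristic would instead check isinstance(module, target_classes).

```python
import torch.nn as nn

def size_policy_targets(root, min_num_params):
    """Return module names to wrap, innermost first. The root is always
    marked so that every parameter ends up covered."""
    covered, targets = set(), []
    # Reversed pre-order: every module is visited after all of its descendants.
    for name, module in reversed(list(root.named_modules())):
        uncovered = [p for p in module.parameters() if id(p) not in covered]
        if name == "" or sum(p.numel() for p in uncovered) >= min_num_params:
            covered.update(id(p) for p in uncovered)
            targets.append(name or "<root>")
    return targets

model = nn.Sequential(
    nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)),  # "0": 8320 params
    nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)),  # "1": 8320 params
    nn.Linear(64, 10),                                               # "2": 650 params
)
print(size_policy_targets(model, min_num_params=5000))  # ['1', '0', '<root>']
```

A single Linear(64, 64) has 4160 parameters, below the threshold, so each block is wrapped as a unit and the small head falls through to the root.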

Historical: The original design employed a double module wrapping with a second module wrapper: FullyShardedDataParallel(FlattenParamsWrapper(module)). FullyShardedDataParallel was responsible for all-gathering, and FlattenParamsWrapper was responsible for setting the original parameters to be views into the all-gathered FlatParameter. We consider this all-gather plus view-setting together to be the (logical) unshard.

Constraints: Unsharding Parameters for Computation

Constraint 3: For correctness, a module’s parameters must be unsharded during its forward/backward computation and only resharded outside that. For memory performance, the unsharded lifetime should be minimized, and for throughput performance, the number of unshards/reshards should be minimized.

Pre/Post-Forward/Backward

To address the throughput part of Constraint 3:

Rule 2: For a given FlatParameter and forward/backward pass, FSDP only unshards and reshards the FlatParameter once.

This ensures the minimal number of all-gathers for a fixed number of FlatParameters without changing the algorithm, namely 2x the number of FlatParameters per training iteration (one all-gather in forward and one in backward).

Given Rules 1 and 2 and Constraint 3, the choice of when to unshard/reshard becomes fixed. For the module that owns a FlatParameter (unique by Rule 1), FSDP should unshard right before the module’s forward and reshard right after. Similarly, FSDP should unshard right before gradient computation for any tensor in the module’s forward output and reshard after the FlatParameter’s gradient computation. This defines four points per FlatParameter: pre-forward unshard, post-forward reshard, pre-backward unshard, and post-backward reshard (in accordance with Rule 2).
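The four points map naturally onto PyTorch hooks. The sketch below only logs where an unshard/reshard would occur, assuming each module owns a single parameter (bias=False) that stands in for its FlatParameter. It uses the public nn.Module forward pre-/post-hooks and Tensor.register_hook, whereas FSDP’s real backward hooks are more involved. One forward/backward fires each point once per module, i.e. two unshards per FlatParameter, consistent with Rule 2.

```python
import torch
import torch.nn as nn

events = []

def add_logging_hooks(module: nn.Module, name: str) -> None:
    """Log the four points where FSDP would unshard/reshard the (hypothetical)
    FlatParameter owned by `module`. Nothing is actually sharded here."""
    def pre_forward(mod, args):
        events.append(("unshard", name, "pre-forward"))

    def post_forward(mod, args, output):
        events.append(("reshard", name, "post-forward"))
        if output.requires_grad:
            # Fires once the gradient w.r.t. this output is ready, i.e. right
            # before this module's backward computation runs.
            output.register_hook(lambda grad: events.append(("unshard", name, "pre-backward")))

    def post_backward(grad):
        # Fires once this module's (single) parameter has its gradient.
        events.append(("reshard", name, "post-backward"))

    module.register_forward_pre_hook(pre_forward)
    module.register_forward_hook(post_forward)
    (param,) = list(module.parameters())  # simplifying assumption: one parameter
    param.register_hook(post_backward)

model = nn.Sequential(nn.Linear(8, 8, bias=False), nn.Linear(8, 2, bias=False))
for i, layer in enumerate(model):
    add_logging_hooks(layer, f"layer{i}")

model(torch.randn(4, 8)).sum().backward()
for event in events:
    print(*event)
```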

FSDP also has some additional logic to run before the entire forward and some to run after the entire backward. These correspond to the root pre-forward and post-backward final callback.

New Abstraction: FlatParamHandle

Previously, we saw how an FSDP wrapping informs a FlatParameter construction. Under the existing design, each FullyShardedDataParallel instance manages the data for its one FlatParameter directly (e.g. the pre/post-forward/backward unshard/reshards). However, applying a FullyShardedDataParallel module wrapper is not actually necessary! I.e., we can achieve the same FlatParameter construction without imposing on the module structure itself.

Moreover, Rule 1 tightly couples one FullyShardedDataParallel instance with one FlatParameter. While this simplifies the design, it is not the best abstraction and can obscure that we are intentionally following this simplifying rule. In reality, once we have fixed a FlatParameter construction, we can have one entity orchestrate the FlatParameters (e.g. their unshards/reshards) and lower the FlatParameter data management to another entity that is strictly per-FlatParameter. For example, the high-level entity can be FullyShardedDataParallel, but the lower-level entity need not be.

From this insight, we introduced another class, FlatParamHandle, that exactly performs the FlatParameter data management and is strictly 1:1 with it. FullyShardedDataParallel only needs to interface with FlatParamHandles. Last year, we refactored FSDP internals to achieve this separation, which resulted in FlatParameter being nothing more than a plain nn.Parameter.
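A toy rendering of this split, under heavy simplifying assumptions: ToyFlatParamHandle is hypothetical and far simpler than FSDP’s FlatParamHandle, and the all-gather is injected as a callable so that the example runs in one process. The point is only the division of labor: the handle owns one flat parameter’s data and implements its unshard (all-gather plus view-setting) and reshard, while whatever orchestrates it merely decides when to call them.

```python
import torch

class ToyFlatParamHandle:
    """Hypothetical stand-in for FlatParamHandle: 1:1 with one flat parameter."""

    def __init__(self, local_shard, shapes, all_gather):
        self.local_shard = local_shard  # this rank's 1D shard
        self.shapes = shapes            # original parameter shapes, in order
        self.all_gather = all_gather    # callable: shard -> full padded 1D tensor
        self.unsharded = None
        self.views = None

    def unshard(self):
        # Logical unshard = all-gather + setting the original params as views.
        self.unsharded = self.all_gather(self.local_shard)
        self.views, offset = [], 0
        for shape in self.shapes:
            n = torch.Size(shape).numel()
            self.views.append(self.unsharded[offset:offset + n].view(shape))
            offset += n
        return self.views

    def reshard(self):
        # The handle drops its references; only the local shard stays resident.
        self.unsharded, self.views = None, None

full = torch.arange(8.0)       # pretend: padded flat param (7 real elements + 1 pad)
shards = list(full.chunk(2))   # rank 0 holds shards[0], rank 1 holds shards[1]

def fake_all_gather(shard):
    # Stands in for all_gather_into_tensor across the two simulated ranks.
    return torch.cat(shards)

handle = ToyFlatParamHandle(shards[0], [(2, 3), (1,)], fake_all_gather)
weight, bias = handle.unshard()
print(weight.shape, bias)  # torch.Size([2, 3]) tensor([6.])
handle.reshard()
```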

New Feature: use_orig_params=True

Part of lowering the FlatParameter data management to FlatParamHandle included enforcing a single code path for unsharding and resharding, respectively, regardless of the calling context (forward/backward, model checkpointing, etc.). This discipline allowed us to augment the unshard and reshard logic in a mostly “just-works” way.

The optimizer step conventionally operates on the registered parameters (returned by nn.Module.parameters()). For the existing design, the FlatParameters are registered, while the original parameters are de-registered and replaced by plain Tensors. Hence, the optimizer step runs on the sharded FlatParameters, and the original parameters are lost.

However, when the parameters are sharded, we can still define semantics that follow the single-program multiple-device (SPMD) paradigm expected for data parallelism. In particular, accessing an original parameter on a rank can return the shard that is present in the rank’s FlatParameter shard or an empty tensor if it is not present, and similarly, the parameter can receive a corresponding gradient if present in the rank’s FlatParameter shard or None if not. One caveat is that due to FSDP’s sharding algorithm, we can only return the flattened sharded original parameter, losing the tensor structure.
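The per-rank semantics reduce to interval intersection. In this sketch (hypothetical helper; offsets and sizes are illustrative), an original parameter occupies a contiguous range of the unsharded flat tensor, and each rank sees the overlap of that range with its own shard, flattened, or an empty tensor if there is no overlap.

```python
import torch

def local_view_of_param(flat_shard, rank, shard_numel, param_offset, param_numel):
    """The parameter spans [param_offset, param_offset + param_numel) of the
    unsharded flat tensor; this rank's shard spans
    [rank * shard_numel, (rank + 1) * shard_numel)."""
    shard_start = rank * shard_numel
    start = max(param_offset, shard_start)
    end = min(param_offset + param_numel, shard_start + shard_numel)
    if start >= end:
        return flat_shard.new_empty(0)  # no part of this parameter on this rank
    return flat_shard[start - shard_start:end - shard_start]  # flattened slice

flat = torch.arange(20.0)         # 15 weight elements, 3 bias elements, 2 padding
shard_numel = flat.numel() // 4   # world size 4 -> 5 elements per rank
for rank in range(4):
    shard = flat[rank * shard_numel:(rank + 1) * shard_numel]
    w = local_view_of_param(shard, rank, shard_numel, param_offset=0, param_numel=15)
    b = local_view_of_param(shard, rank, shard_numel, param_offset=15, param_numel=3)
    print(rank, w.tolist(), b.tolist())
# 0 [0.0, 1.0, 2.0, 3.0, 4.0] []
# 1 [5.0, 6.0, 7.0, 8.0, 9.0] []
# 2 [10.0, 11.0, 12.0, 13.0, 14.0] []
# 3 [] [15.0, 16.0, 17.0]
```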

These semantics to use the original parameters are available today by passing use_orig_params=True to the FSDP constructor, and they were added exactly by augmenting the existing unshard/reshard logic. In that case, named_parameters() returns the original fully-qualified names (FQNs), not ones like <prefix>.flat_param. This enables using multiple optimizer parameter groups and/or different requires_grad within one FlatParameter’s original parameters, and it helps hide the FlatParameter abstraction from users. We hope to converge to setting use_orig_params=True by default in the future.
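For instance, a common weight-decay split can be keyed on parameter names. The sketch below runs on a plain, unwrapped module, because constructing FSDP requires an initialized process group. Per the note, named_parameters() on a model wrapped with use_orig_params=True should return the same original FQNs, so the same name-based grouping should carry over. The grouping rule itself is an illustrative assumption.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16), nn.Linear(16, 4))

def skip_weight_decay(fqn):
    # Illustrative rule: no weight decay for biases or the LayerNorm at index 1.
    return fqn.endswith("bias") or fqn.startswith("1.")

decay = [p for n, p in model.named_parameters() if not skip_weight_decay(n)]
no_decay = [p for n, p in model.named_parameters() if skip_weight_decay(n)]
optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)
print([len(g["params"]) for g in optimizer.param_groups])  # [2, 4]
```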

New API: From Wrapper to No Wrapper – fully_shard

The lowering of FlatParameter data management from FullyShardedDataParallel to FlatParamHandle and the ability to use the original parameters prepared FSDP for a more composable form. This coincided with a broader eager composable API effort, which imposes a contract requiring each API to preserve the module structure and the original FQNs from named_parameters(). Thus, FSDP adhering to the contract required the aforementioned work: without the lowering to FlatParamHandle, FlatParameter data management would depend on the FullyShardedDataParallel module wrapper, and without the original parameters, the FQNs would fail the contract.

With the abstractions and semantics aligned, we migrated the pre/post-forward logic to pre/post-forward nn.Module hooks, generalized existing FullyShardedDataParallel code over an _FSDPState object, and landed the resulting fully_shard API. We hope to provide more details on the composable APIs in the future.
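To illustrate the FQN part of the contract, the toy comparison below contrasts a hypothetical wrapper (ToyWrapper, not FullyShardedDataParallel) with a hypothetical hook-based function (toy_hook_based_shard, not fully_shard) whose hooks are no-ops. Nesting a module under a wrapper changes every FQN, whereas attaching hooks leaves named_parameters() untouched.

```python
import torch.nn as nn

class ToyWrapper(nn.Module):
    """Hypothetical module wrapper: nesting under an attribute renames params."""
    def __init__(self, module):
        super().__init__()
        self.wrapped = module

    def forward(self, *args):
        return self.wrapped(*args)

def toy_hook_based_shard(module):
    """Hypothetical hook-based alternative: behavior is attached via hooks and
    the module structure (hence every FQN) is left as-is."""
    module.register_forward_pre_hook(lambda mod, args: None)  # would unshard here
    module.register_forward_hook(lambda mod, args, out: None)  # would reshard here
    return module

def make_model():
    return nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

original = [n for n, _ in make_model().named_parameters()]
wrapped = [n for n, _ in ToyWrapper(make_model()).named_parameters()]
hooked = [n for n, _ in toy_hook_based_shard(make_model()).named_parameters()]
print(original)            # ['0.weight', '0.bias', '1.weight', '1.bias']
print(wrapped)             # ['wrapped.0.weight', 'wrapped.0.bias', 'wrapped.1.weight', 'wrapped.1.bias']
print(hooked == original)  # True
```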

We diagram the design changes above. Originally, FullyShardedDataParallel was 1:1 with FlatParameter, which consists of some number of original parameters. We added FlatParamHandle as a middle layer, as represented by arrow (1). Next, we can imagine replacing FullyShardedDataParallel with another high-level entity and registering hooks on the module instead, as represented by arrow (2). This step only exists logically. Finally, we introduced fully_shard as the high-level entity, which can be 1:k with respect to modules/FlatParamHandles/FlatParameters, as represented by arrow (3).

Future Directions

Here, we focus on two future directions that reflect the need for design flexibility.

Revisiting FlatParameter Construction

Recall that the key question was how to construct FlatParameters for a given model. One view is that there are two (non-disjoint) approaches to further tackle that question:

  1. Improve our ability to search the existing set of constructions for performant ones.
  2. Expand our set of possible constructions to include more performant ones.

Following Approach 1 may include providing more guidance on how to apply FSDP or providing an improved auto wrapping policy. Following Approach 2 is more nuanced. For example, we may parameterize the set of constructions by (1) the number of unshard/reshard pairs per FlatParameter per forward/backward pass and (2) the number of modules involved per unshard/reshard pair. For this note, we focus on two classes in this parameterization:

  • Current class: (1 unshard/reshard pair, 1 module per unshard/reshard pair)
  • New class: (1 unshard/reshard pair, 2 modules per unshard/reshard pair)

A construction in the new class may be such that one submodule unshards and later a different submodule reshards, hence the 2 modules per unshard/reshard pair. This enables constructions where two sibling modules are grouped together into one FlatParameter , without including their parent module. Note that such a construction violates Rule 1 (but not 2), which means that the existing FullyShardedDataParallel wrapper is incompatible.
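A toy version of a (1 unshard/reshard pair, 2 modules per unshard/reshard pair) construction, assuming a fixed execution order in which attn runs before mlp: the hypothetical ToyHandle tracks one group spanning two sibling modules, the first sibling’s forward pre-hook unshards it, and the second sibling’s forward post-hook reshards it. The sibling norm stays outside the group, whereas under Rule 1 grouping attn and mlp would require wrapping the parent, which would also pull in norm’s parameters.

```python
import torch
import torch.nn as nn

log = []

class ToyHandle:
    """Hypothetical handle that only tracks whether its group is unsharded."""
    def __init__(self, name):
        self.name, self.is_unsharded = name, False

    def unshard(self):
        if not self.is_unsharded:
            self.is_unsharded = True
            log.append(f"unshard {self.name}")

    def reshard(self):
        if self.is_unsharded:
            self.is_unsharded = False
            log.append(f"reshard {self.name}")

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(8, 8)  # siblings grouped into one handle...
        self.mlp = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)  # ...while this sibling is not in the group

    def forward(self, x):
        return self.norm(self.mlp(self.attn(x)))

block = Block()
handle = ToyHandle("attn+mlp")
# First sibling unshards before its forward; second sibling reshards after its forward.
block.attn.register_forward_pre_hook(lambda mod, args: handle.unshard())
block.mlp.register_forward_hook(lambda mod, args, out: handle.reshard())

block(torch.randn(2, 8))
print(log)  # ['unshard attn+mlp', 'reshard attn+mlp']
```

The reliance on attn always running before mlp is exactly the kind of execution-order assumption that makes choosing such constructions robustly hard.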

Last summer, we explored this new class of constructions and found promising throughput and memory gains. However, searching this class and choosing a performant construction pose a challenge. Our approach used the first iteration’s execution order to choose the construction, but we have found several blockers to its robustness.

The above diagram shows the transition from today’s fully_shard, corresponding to (1 unshard/reshard pair, 1 module per unshard/reshard pair), on the left to one corresponding to (1 unshard/reshard pair, 2 modules per unshard/reshard pair) on the right. Because fully_shard orchestrates the FlatParamHandles and only needs to register hooks on modules, we may replace the single module that is 1:1 with each FlatParamHandle by multiple modules that are n:1 with it, where, notably, these modules can be siblings.

Memory Efficient Fine-Tuning

In fine-tuning, the majority of model parameters are frozen, i.e. do not require gradients, which saves gradient and optimizer state memory. For FSDP, the FlatParameter construct introduces an issue since it owns the storage of multiple original parameters. For use_orig_params=False, the user cannot even specify different requires_grad across the original parameters of a FlatParameter. For use_orig_params=True, the specification is possible, but as long as one original parameter receives a gradient, the entire FlatParameter gets a gradient, where frozen parameters’ gradients manifest as zeros.
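A single-process illustration with plain tensors (not FSDP): two views into one flat tensor stand in for a trainable and a frozen original parameter, with freezing modeled by detaching the frozen view. Autograd still materializes a gradient for the whole flat tensor, and the frozen part is zeros.

```python
import torch

flat = torch.zeros(6, requires_grad=True)  # stands in for a FlatParameter
trainable, frozen = flat[:3], flat[3:]     # views of two original parameters
x = torch.ones(3)
# The frozen parameter is used in forward but must not receive a gradient,
# modeled here by detaching it.
loss = (trainable * x).sum() + (frozen.detach() * x).sum()
loss.backward()
print(flat.grad)  # tensor([1., 1., 1., 0., 0., 0.])
```

The gradient buffer has the full size of the flat tensor even though half of it is always zero.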

We mention one solution. We may relax Rule 1 and enable (up to) two FlatParameters per module: one for the parameters that require gradients and one for those that do not. This trades off increased communication overhead (up to 2x the number of collectives) for decreased gradient memory usage. Our implementation already accommodates multiple FlatParamHandles per module except for a few places for model/optimizer state checkpointing, so the important questions are really how users should invoke this code path and whether the complexity it introduces to the system is justifiable.
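A sketch of how the relaxed rule might partition a module’s parameters, using a hypothetical LoRA-style layer (illustrative, not a real library class) and a hypothetical split_by_requires_grad helper. Each non-empty group would become its own FlatParameter, so only the trainable one would need gradient memory.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Toy LoRA-style layer: frozen base Linear plus two small trainable matrices."""
    def __init__(self, dim, rank):
        super().__init__()
        self.base = nn.Linear(dim, dim)
        self.base.requires_grad_(False)  # freeze the base weight and bias
        self.lora_a = nn.Parameter(torch.zeros(rank, dim))
        self.lora_b = nn.Parameter(torch.zeros(dim, rank))

def split_by_requires_grad(module):
    """Up to two groups per module, each a candidate FlatParameter."""
    groups = {"trainable": [], "frozen": []}
    for name, param in module.named_parameters():
        groups["trainable" if param.requires_grad else "frozen"].append(name)
    return {k: v for k, v in groups.items() if v}  # drop an empty group

print(split_by_requires_grad(LoRALinear(16, 4)))
# {'trainable': ['lora_a', 'lora_b'], 'frozen': ['base.weight', 'base.bias']}
```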

The above diagram shows the (up to) 2 FlatParamHandles per module setup on the right compared to today’s fully_shard on the left, which only permits 1 FlatParamHandle per module.

Final Thoughts

Rethinking the FSDP design from first principles can help expose new opportunities and prepare us for the future. For example, decoupling the sharding factor from the world size previously gave Hybrid Sharded Data Parallel (HSDP). Looking forward, if NCCL offers efficient all-gatherv/reduce-scatterv, relaxing Constraint 1, then FSDP can adopt a new tensor-shape preserving sharding algorithm that allows non-pointwise optimizers like Shampoo. Finally, we may leverage compiler techniques to give us a better way to search for performant FlatParameter constructions and/or support a richer class of constructions.
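One possible shape-preserving scheme, sketched under the assumption of dim-0 sharding with torch.chunk (not a committed FSDP design): each rank keeps whole rows, so a 2D weight stays a matrix, which is what matrix-based optimizers like Shampoo need. The catch is that shard sizes become uneven, which is why efficient all-gatherv/reduce-scatterv would matter.

```python
import torch

def shard_dim0(param, world_size):
    """Split along dim 0 so that every rank keeps whole rows."""
    chunks = list(torch.chunk(param, world_size, dim=0))
    # torch.chunk may return fewer than world_size chunks; pad with empty shards.
    while len(chunks) < world_size:
        chunks.append(param.new_empty((0, *param.shape[1:])))
    return chunks

weight = torch.randn(10, 3)
shards = shard_dim0(weight, world_size=4)
print([tuple(s.shape) for s in shards])  # [(3, 3), (3, 3), (3, 3), (1, 3)]
```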

More broadly, eager FSDP serves as a learning ground that prepares us to design for more advanced parallelisms and their compositions. FSDP’s FlatParameter and FlatParamHandle simply represent a grouping of tensors sharing a contiguous storage and its data-managing entity, respectively—we may find these fundamental abstractions to be useful beyond FSDP.


Big thanks to the Fairscale FSDP team for creating such a performant and general implementation of ZeRO-3 and for supporting the upstream to PyTorch!


Awesome! I’m a bit confused, though: do we really know how to flatten a layer’s parameters if any of them is non-contiguous, e.g. channels-last or sparse tensors?

Good point! We do not support non-contiguous tensors at the moment.

Do you have any practical examples where a parameter is using channels last or sparse format?

Actually, no. I am somewhat interested in the channels-last format, but the current implementation of this feature sometimes confuses me, so I am not ready to use it. I am very interested in the discussions about FlatParameter construction and forward/backward prefetch in this article. In fact, I am also researching whether it is possible to break through module boundaries and search for the best splitting policy based on some kind of ‘graph’ structure, though this is still far off. Currently, I am trying to integrate my team’s ‘CUDA-like’ devices into the current FSDP implementation, at least enough to construct FSDP, so that I can evaluate the sharding effect based on my policy and the FSDP wrapper in PyTorch master.


Thanks for sharing! I want to clarify: how can I use multiple FlatParamHandles per module? I got the error “FlatParameter requires uniform requires_grad” when training with LoRA.

The current design assumes at most one FlatParamHandle per module (but possibly multiple modules per FlatParamHandle). This corresponds to Rule 1 above since a FullyShardedDataParallel instance is 1:1 with a FlatParamHandle, which is 1:1 with a FlatParameter.

For the error you are hitting, maybe you can file an issue on Github with more context, and we can take a look?

@awgu Do you see this as also enabling FSDP: enhanced shared parameter support? Are there any plans for this?

Parameter sharing seems like it could be handled using multiple FlatParamHandle per module, similar to what you mention in your section on Memory Efficient Fine-Tuning.

Hi, I was reading this paragraph:

  • This includes the ability for the existing FullyShardedDataParallel to expose the original parameters (not FlatParameters) via use_orig_params=True, which enables flexible support for multiple parameter groups.

I had a hard time understanding what use_orig_params meant. Does it mean that it allows us to include frozen params alongside trainable params? If that’s the case, do you think a better flag (perhaps for future major version) would be allow_frozen_params instead? Otherwise, I might be misunderstanding (please let me know if that’s the case!)

Thanks for the question!

use_orig_params=True does mean we can include frozen and trainable parameters together in the same FlatParameter. However, it does not only have to be for that purpose. Like the part you quoted says, you can use it to implement multiple optimizer parameter groups such as when different parameters have different weight decays (with all parameters trainable).