Given some interest, I am sharing a note (first written internally) on the PyTorch Fully Sharded Data Parallel (FSDP) design. This covers much but not all of it (e.g. it excludes autograd and CUDA caching allocator interaction). I can share more details if there is further interest.
- We rethought the PyTorch FSDP design from first principles to uncover a new one that takes a first step toward improving composability and flexibility.
- This includes an experimental fully_shard API that is part of a broader eager distributed composable API effort. This is a work in progress and not ready for general use yet.
- This includes the ability for the existing FullyShardedDataParallel to expose the original parameters (not FlatParameters) via use_orig_params=True, which enables flexible support for multiple parameter groups.
In data parallelism, parameters are replicated across ranks, and each rank operates on a local batch, computing local gradients. Ranks all-reduce their local gradients to compute the global gradient and run a replicated optimizer step.
In sharded data parallelism, parameters are sharded across ranks, and ranks all-gather them as needed for unsharded forward/backward computation and free them when unneeded. Since each rank only manages a shard of the parameters, each rank only needs the corresponding gradient shard to run the optimizer step. This implies that ranks reduce-scatter their local unsharded gradients to compute the sharded global gradient and run a sharded optimizer step.
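The following single-process sketch makes the collective semantics concrete. It is an illustration, not PyTorch's API. Each "rank" is a list entry, and all_reduce, all_gather, and reduce_scatter are hypothetical pure-Python stand-ins for the corresponding NCCL collectives. The sketch checks the property the paragraph relies on: the reduce-scattered gradient shard equals the matching slice of the all-reduced gradient.

```python
from typing import List

def all_reduce(local_grads: List[List[float]]) -> List[List[float]]:
    # Every rank ends up with the elementwise sum of all ranks' gradients.
    total = [sum(vals) for vals in zip(*local_grads)]
    return [list(total) for _ in local_grads]

def all_gather(shards: List[List[float]]) -> List[List[float]]:
    # Every rank ends up with the concatenation of all ranks' shards.
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

def reduce_scatter(local_grads: List[List[float]]) -> List[List[float]]:
    # Sum across ranks, then rank r keeps only the r-th equal-size chunk.
    world_size = len(local_grads)
    total = [sum(vals) for vals in zip(*local_grads)]
    assert len(total) % world_size == 0, "even sizes assumed"
    chunk = len(total) // world_size
    return [total[r * chunk:(r + 1) * chunk] for r in range(world_size)]

# Two ranks; a 4-element parameter sharded 2 + 2.
param_shards = [[1.0, 2.0], [3.0, 4.0]]
unsharded = all_gather(param_shards)        # each rank: [1.0, 2.0, 3.0, 4.0]
local_grads = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
grad_shards = reduce_scatter(local_grads)   # rank r: r-th half of the summed grad
full_grads = all_reduce(local_grads)
# The reduce-scattered shard equals the matching slice of the all-reduced gradient.
assert grad_shards[1] == full_grads[1][2:4]
```

Because each rank receives only its own summed chunk, it can run the optimizer step on just that chunk.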
In sharded data parallelism, a rank’s peak memory includes only the sharded parameters, sharded gradients, and sharded optimizer states, plus memory for the largest subset of simultaneously unsharded parameters. These memory savings may be critical for training large models or scaling batch size.
In fully sharded data parallelism, the sharding factor is specifically set to the world size. I.e., the parameters are sharded across all ranks without replication. This algorithm is commonly called ZeRO-3, and PyTorch’s Fully Sharded Data Parallel (FSDP) is one implementation, where a central challenge is working within the PyTorch framework. (The sharding factor need not be the world size; setting it to be the number of intra-node devices gives the alternative Hybrid Sharded Data Parallel (HSDP).)
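A back-of-envelope estimate shows why the sharding factor matters. The sketch is illustrative only. It assumes fp32 parameters and gradients plus an Adam-style optimizer with two fp32 states per parameter, for 16 bytes per parameter in total. It ignores activations, communication buffers, transient unsharded gradients, and allocator fragmentation. The helper per_rank_bytes and the model/cluster sizes are hypothetical.

```python
def per_rank_bytes(num_params: int, sharding_factor: int,
                   max_unsharded_params: int) -> int:
    """Rough per-rank bytes under the stated assumptions (illustrative only)."""
    bytes_per_param = 4 + 4 + 8   # fp32 param + fp32 grad + two fp32 Adam states
    sharded = num_params * bytes_per_param // sharding_factor
    # Temporarily unsharded parameters are materialized in full (4 bytes each).
    return sharded + 4 * max_unsharded_params

# Hypothetical setup: 1B parameters, 2 nodes x 8 GPUs, and at most 50M
# parameters unsharded at once.
fsdp = per_rank_bytes(1_000_000_000, sharding_factor=16, max_unsharded_params=50_000_000)
hsdp = per_rank_bytes(1_000_000_000, sharding_factor=8, max_unsharded_params=50_000_000)
# Replicated data parallelism: sharding factor 1, nothing extra to unshard.
ddp = per_rank_bytes(1_000_000_000, sharding_factor=1, max_unsharded_params=0)
print(fsdp / 1e9, hsdp / 1e9, ddp / 1e9)   # 1.2 2.2 16.0 (GB)
```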
PyTorch FSDP was upstreamed from Fairscale FSDP. This note provides one possible perspective on this evolving FSDP design from first principles and motivates a modified design striving for improved composability. Throughout, we treat performance as a first-class constraint since a non-performant design is not usable.
FSDP targets NVIDIA GPUs and uses NCCL for collective communications, and FSDP requires all-gather to unshard parameters and reduce-scatter to reduce gradients. PyTorch’s all_gather_into_tensor requires even input sizes across ranks but is more performant than all_gather, and the same holds for reduce_scatter_tensor relative to reduce_scatter. Moreover, for a fixed communication volume, batching data and issuing fewer collectives is more performant. This yields the following two constraints:
Constraint 1: FSDP should communicate even sizes across ranks to use all_gather_into_tensor and reduce_scatter_tensor.
Constraint 2: FSDP should batch parameters for all-gather and gradients for reduce-scatter.
Constraints 1 and 2 motivate the FlatParameter abstraction: a 1D tensor that is the concatenation of n flattened original parameters (with optional right-padding to ensure evenness). It serves as FSDP’s atomic unit of communication, and it owns the storage of the constituent original parameters.
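The flattening idea can be sketched with standard torch operations. The helpers flatten_params and views_into are hypothetical and much simpler than PyTorch’s actual FlatParameter, which is an nn.Parameter subclass carrying extra metadata. The sketch concatenates flattened parameters and right-pads the result to a multiple of the world size. It then rebuilds the original shapes as views, so the flat tensor owns the storage.

```python
import torch
import torch.nn.functional as F

def flatten_params(params, world_size):
    """Concatenate flattened params into one 1D tensor, right-padded so that
    its numel is divisible by world_size. Returns (flat tensor, padding)."""
    flat = torch.cat([p.detach().reshape(-1) for p in params])
    pad = (-flat.numel()) % world_size
    if pad:
        flat = F.pad(flat, (0, pad))   # zero right-padding for evenness
    return flat, pad

def views_into(flat, shapes):
    """Return views of `flat` with the original shapes (shared storage)."""
    views, offset = [], 0
    for shape in shapes:
        numel = torch.Size(shape).numel()
        views.append(flat[offset:offset + numel].view(shape))
        offset += numel
    return views

weight, bias = torch.randn(3, 2), torch.randn(3)   # 6 + 3 = 9 elements
flat, pad = flatten_params([weight, bias], world_size=4)
print(flat.numel(), pad)                           # 12 3
shard_size = flat.numel() // 4
rank1_shard = flat[1 * shard_size:2 * shard_size]  # the chunk rank 1 would keep
w_view, b_view = views_into(flat, [weight.shape, bias.shape])
flat.zero_()   # writes through to the views: the flat tensor owns the storage
assert w_view.abs().sum().item() == 0.0
```

Each rank keeps only one of the equal-size chunks, here rank1_shard. That even chunking is what lets FSDP use the even-size collectives.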
When computation involves an original parameter, the owning FlatParameter must be unsharded and can only be resharded outside the computation. Thus, a performant FlatParameter construction groups original parameters involved in computation around the same time, and ideally each FlatParameter is only unsharded/resharded once for a given forward or backward to minimize the number of all-gathers.
Key Question: For a given model, how should FSDP construct FlatParameters from the model’s parameters?
FullyShardedDataParallel is a module wrapper, applying transformations to the wrapped module at construction time and runtime (i.e. forward/backward/optimizer step). As an eager API, it only has access to the static module structure at construction time. FSDP wrapping leverages the module structure to inform the FlatParameter construction, in the hope that model authors group parameters with the desired locality into the same module or module subtrees.
Rule 1: If the user wraps fsdp_module = FullyShardedDataParallel(module), then every parameter in module not already flattened is flattened into a single new FlatParameter and assigned to fsdp_module.
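Rule 1 can be stated as a small bottom-up traversal. The sketch uses a hypothetical ToyModule tree with string parameter names rather than real nn.Modules. A wrapped module claims every parameter in its subtree that no wrapped descendant has already claimed. Parameters outside every wrapped subtree are left unassigned here.

```python
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class ToyModule:
    name: str
    params: List[str] = field(default_factory=list)   # directly owned params
    children: List["ToyModule"] = field(default_factory=list)
    wrapped: bool = False                              # FSDP-wrapped?

def assign_flat_params(root: ToyModule) -> Dict[str, List[str]]:
    """Rule 1, bottom-up: map each wrapped module to the parameters flattened
    into its single FlatParameter."""
    assignment: Dict[str, List[str]] = {}

    def visit(m: ToyModule) -> List[str]:
        # Returns the parameters in m's subtree that are still unflattened.
        pending = list(m.params)
        for child in m.children:
            pending += visit(child)
        if m.wrapped:
            assignment[m.name] = pending   # one new FlatParameter for m
            return []
        return pending

    visit(root)
    return assignment

tree = ToyModule("root", ["root.embed"], wrapped=True, children=[
    ToyModule("layer0", ["layer0.w", "layer0.b"], wrapped=True),
    ToyModule("layer1", ["layer1.w"], children=[ToyModule("layer1.norm", ["layer1.norm.w"])]),
    ToyModule("head", ["head.w"], wrapped=True),
])
print(assign_flat_params(tree))
# {'layer0': ['layer0.w', 'layer0.b'], 'head': ['head.w'],
#  'root': ['root.embed', 'layer1.w', 'layer1.norm.w']}
```

The unwrapped layer1 subtree falls through to the nearest wrapped ancestor, which is the kind of assignment a user can read off a model printout.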
Rule 1 is just one way to leverage module structure, which happens to be sufficiently performant while being simple to reason about. From a model printout, a user/developer can quickly infer the parameter assignment. This simplicity can be crucial to debugging and adoption. An example module tree is shown above, where red indicates a FullyShardedDataParallel-wrapped module and yellow indicates a non-directly-wrapped module. This wrapping constructs four FlatParameters (shown by the dotted lines), which are assigned to modules 0, 1, 3, and 7, respectively.
To be clear, if a user only applies FullyShardedDataParallel to the root module, then there is only a single FlatParameter consisting of all original parameters. That means no communication/computation overlap, and the entire unsharded parameter size contributes to peak memory. Thus, we need multiple FlatParameters, which, by Rule 1, requires recursive/nested FSDP wrapping.
Manual wrapping refers to applying submodule = FullyShardedDataParallel(submodule) on target submodules, where this application proceeds bottom-up. For large/complex models, manual wrapping may not be tractable, so we introduced auto wrapping, which uses heuristics to perform the nested wrapping automatically. Heuristics include a parameter size threshold and target nn.Module classes to wrap.
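PyTorch provides ready-made policies such as size_based_auto_wrap_policy and transformer_auto_wrap_policy in torch.distributed.fsdp.wrap. The stand-alone sketch below does not use them. It only illustrates the idea behind a parameter size threshold, under simplifying assumptions: a hypothetical ToyModule tree with per-module numel counts, and a root that is always wrapped. Children are visited first, mirroring the bottom-up application of manual wrapping.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyModule:
    name: str
    numel: int = 0                                  # params owned directly
    children: List["ToyModule"] = field(default_factory=list)

def size_based_wrap(root: ToyModule, min_num_params: int) -> List[str]:
    """Return module names to wrap, in bottom-up order. A non-root module is
    wrapped once its not-yet-wrapped subtree parameters reach the threshold;
    the root is always wrapped to manage any leftovers."""
    wrapped: List[str] = []

    def visit(m: ToyModule) -> int:
        unwrapped = m.numel + sum(visit(c) for c in m.children)
        if unwrapped >= min_num_params and m is not root:
            wrapped.append(m.name)
            return 0                                # claimed by m's FlatParameter
        return unwrapped

    visit(root)
    wrapped.append(root.name)
    return wrapped

model = ToyModule("model", 1_000, [
    ToyModule("block0", children=[ToyModule("block0.attn", 400), ToyModule("block0.mlp", 800)]),
    ToyModule("block1", children=[ToyModule("block1.attn", 400), ToyModule("block1.mlp", 800)]),
    ToyModule("head", 300),
])
print(size_based_wrap(model, min_num_params=1_000))   # ['block0', 'block1', 'model']
print(size_based_wrap(model, min_num_params=500))     # ['block0.mlp', 'block1.mlp', 'model']
```

A lower threshold wraps smaller subtrees, but it can leave a parent below the threshold. With 500, for example, block0’s remaining 400 attention parameters fall through to the root.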
Historical: The original design employed a double module wrapping with a second module wrapper: FullyShardedDataParallel was responsible for all-gathering, and FlattenParamsWrapper was responsible for setting the original parameters to be views into the all-gathered FlatParameter. We consider this all-gather plus view-setting together to be the (logical) unshard.
Constraint 3: For correctness, a module’s parameters must be unsharded during its forward/backward computation and only resharded outside that. For memory performance, the unsharded lifetime should be minimized, and for throughput performance, the number of unshard/reshards should be minimized.
To address the throughput part of Constraint 3:
Rule 2: For a given FlatParameter and forward/backward pass, FSDP only unshards and reshards the FlatParameter once.
This ensures the minimal number of all-gathers for a fixed number of FlatParameters without changing the algorithm, namely 2x the number of FlatParameters (one all-gather per FlatParameter in forward and one in backward).
Given Rules 1 and 2 and Constraint 3, the choice of when to unshard/reshard becomes fixed. For the module that owns a FlatParameter (unique by Rule 1), FSDP should unshard right before the module’s forward and reshard right after. Similarly, FSDP should unshard right before gradient computation for any tensor in the module’s forward output and reshard after the FlatParameter’s gradient computation. This defines four points per FlatParameter: pre-forward unshard, post-forward reshard, pre-backward unshard, and post-backward reshard (in accordance with Rule 2).
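The four points per FlatParameter can be simulated without any real communication. In this hypothetical sketch, each Owner is a module owning one FlatParameter. The forward runs children in order inside the parent, and the backward visits them in reverse, a simplification of real autograd ordering. Counting unshards confirms the 2x bound from Rule 2. It also shows that a parent’s FlatParameter stays unsharded while its children run. Optimizations such as prefetching are ignored.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Owner:
    name: str                                   # module owning one FlatParameter
    children: List["Owner"] = field(default_factory=list)

class Tracker:
    """Records unshard/reshard events; each unshard stands for one all-gather."""
    def __init__(self) -> None:
        self.unsharded: set = set()
        self.all_gathers = 0
        self.peak = 0                           # max FlatParameters unsharded at once
        self.log: List[str] = []

    def unshard(self, name: str) -> None:
        assert name not in self.unsharded, "never all-gather twice without a reshard"
        self.unsharded.add(name)
        self.all_gathers += 1
        self.peak = max(self.peak, len(self.unsharded))
        self.log.append(f"unshard {name}")

    def reshard(self, name: str) -> None:
        self.unsharded.remove(name)
        self.log.append(f"reshard {name}")

def forward(m: Owner, t: Tracker) -> None:
    t.unshard(m.name)                   # pre-forward unshard
    for child in m.children:            # the module's forward, including children
        forward(child, t)
    t.reshard(m.name)                   # post-forward reshard

def backward(m: Owner, t: Tracker) -> None:
    t.unshard(m.name)                   # pre-backward unshard
    for child in reversed(m.children):  # gradients flow in reverse order
        backward(child, t)
    t.reshard(m.name)                   # post-backward reshard (then reduce-scatter)

root = Owner("root", [Owner("layer0"), Owner("layer1"), Owner("layer2")])
t = Tracker()
forward(root, t)
backward(root, t)
print(t.all_gathers, t.peak)            # 8 2  (2 x 4 FlatParameters; root + one layer)
```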
FSDP also has some additional logic to run before the entire forward and some to run after the entire backward. These correspond to the root pre-forward and post-backward final callback.
Previously, we saw how an FSDP wrapping informs a FlatParameter construction. Under the existing design, each FullyShardedDataParallel instance manages the data for its one FlatParameter directly (e.g. the pre/post-forward/backward unshard/reshards). However, applying a FullyShardedDataParallel module wrapper is not actually necessary! I.e., we can achieve the same FlatParameter construction without imposing on the module structure itself.
Moreover, Rule 1 tightly couples one FullyShardedDataParallel instance with one FlatParameter. While this simplifies the design, this is not the best abstraction and can obscure that we are intentionally following this simplifying rule. In reality, once we have fixed a FlatParameter construction, one entity can orchestrate the FlatParameters (e.g. their unshard/reshards). We can then lower the FlatParameter data management to another entity that is strictly per-FlatParameter. For example, the high-level entity can be FullyShardedDataParallel, but the lower-level entity need not be.
From this insight, we introduced another class, FlatParamHandle, that exactly performs the FlatParameter data management and is strictly 1:1 with it. FullyShardedDataParallel only needs to interface with FlatParamHandles. Last year, we refactored FSDP internals to achieve this separation, which resulted in FlatParameter being nothing more than a plain nn.Parameter.
Part of lowering the FlatParameter data management to FlatParamHandle included enforcing a single code path each for unsharding and resharding, regardless of the calling context (forward/backward, model checkpointing, etc.). This discipline allowed us to augment the unshard and reshard logic in a mostly “just-works” way.
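The single-code-path discipline can be sketched as one handle class plus one context manager. The handle’s unshard/reshard are the only way to materialize data, and every caller goes through the context manager. This is loosely in the spirit of FSDP’s summon_full_params, though it is not that implementation. FlatParamHandle here is a hypothetical, heavily simplified stand-in, and the all-gather is simulated by concatenating shards that are passed in directly.

```python
from contextlib import contextmanager
from typing import Iterator, List, Optional

class FlatParamHandle:
    """Hypothetical, simplified handle: the only code path that unshards and
    reshards one flat parameter. `all_shards` simulates the process group."""
    def __init__(self, rank: int, all_shards: List[List[float]]) -> None:
        self.rank = rank
        self.all_shards = all_shards
        self.local_shard = all_shards[rank]            # persistently owned data
        self.unsharded: Optional[List[float]] = None   # None: currently resharded
        self.num_unshards = 0

    def unshard(self) -> None:
        if self.unsharded is None:                     # skip redundant all-gathers
            self.unsharded = [x for shard in self.all_shards for x in shard]
            self.num_unshards += 1

    def reshard(self) -> None:
        self.unsharded = None                          # free the unsharded data

@contextmanager
def unsharded_params(handles: List[FlatParamHandle]) -> Iterator[None]:
    # Every caller (forward, backward, checkpointing, ...) goes through here,
    # so a change to unshard/reshard applies everywhere at once.
    for h in handles:
        h.unshard()
    try:
        yield
    finally:
        for h in handles:
            h.reshard()

h = FlatParamHandle(rank=0, all_shards=[[1.0, 2.0], [3.0, 4.0]])
with unsharded_params([h]):                            # e.g. forward computation
    out = sum(h.unsharded)
with unsharded_params([h]):                            # e.g. full state dict saving
    state = {"flat_param": list(h.unsharded)}
print(out, state, h.unsharded, h.num_unshards)
# 10.0 {'flat_param': [1.0, 2.0, 3.0, 4.0]} None 2
```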
The optimizer step conventionally operates on the registered parameters (returned by nn.Module.parameters()). For the existing design, the FlatParameters are registered, while the original parameters are de-registered and replaced by plain Tensors. Hence, the optimizer step runs on the sharded FlatParameters, and the original parameters are lost.
However, when the parameters are sharded, we can still define semantics that follow the single-program multiple-data (SPMD) paradigm expected for data parallelism. In particular, accessing an original parameter on a rank can return the portion of it present in the rank’s FlatParameter shard, or an empty tensor if none is present. Similarly, the parameter can receive a corresponding gradient if present in the rank’s FlatParameter shard, or None if not. One caveat is that due to FSDP’s sharding algorithm, we can only return the flattened sharded original parameter, losing the tensor structure.
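Which slice of each original parameter a rank holds follows from interval arithmetic over the flat parameter. The hypothetical local_ranges helper assumes parameters are flattened in order and right-padded to a multiple of the world size, as described earlier. An empty range corresponds to exposing an empty tensor and a None gradient on that rank.

```python
from typing import List, Tuple

def local_ranges(numels: List[int], world_size: int, rank: int) -> List[Tuple[int, int]]:
    """For each original parameter (in flattening order), return the [start, end)
    range of its flattened elements present in `rank`'s shard; (0, 0) if absent."""
    total = sum(numels)
    padded = total + (-total) % world_size     # right-pad for evenness
    shard = padded // world_size
    lo, hi = rank * shard, (rank + 1) * shard  # this rank's range in the flat param
    ranges, offset = [], 0
    for n in numels:
        start = min(max(lo, offset), offset + n)
        end = max(min(hi, offset + n), start)
        ranges.append((start - offset, end - offset) if end > start else (0, 0))
        offset += n
    return ranges

# A 3x2 weight (6 elements) and a bias (3 elements) over 4 ranks: 9 -> padded to 12.
for r in range(4):
    print(r, local_ranges([6, 3], world_size=4, rank=r))
# 0 [(0, 3), (0, 0)]
# 1 [(3, 6), (0, 0)]
# 2 [(0, 0), (0, 3)]
# 3 [(0, 0), (0, 0)]
```

Rank 3 holds only padding, so it would expose empty tensors for both parameters. Rank 0 sees a 3-element flattened slice of the weight rather than a 2D tensor, which is the caveat noted above.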
These semantics for using the original parameters are available today by passing use_orig_params=True to the FSDP constructor, and they were added exactly by augmenting the existing unshard/reshard logic. In that case, named_parameters() returns the original fully-qualified names (FQNs), not ones like <prefix>.flat_param. This enables using multiple optimizer parameter groups and/or different requires_grad within one FlatParameter’s original parameters, and this helps hide the FlatParameter abstraction from users. We hope to converge to setting use_orig_params=True by default in the future.
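The sketch below illustrates why FQNs matter for parameter groups. It runs on a plain nn.Module so that it is self-contained. With FSDP, the idea is to build the groups from a model wrapped with use_orig_params=True, whose named_parameters() yields the original FQNs. The build_param_groups helper and the no-decay-for-biases recipe are illustrative assumptions. Parameters are classified by name rather than by shape, because, per the caveat above, the exposed sharded parameters are flattened.

```python
import torch
import torch.nn as nn

def build_param_groups(model: nn.Module, weight_decay: float = 0.01):
    """Illustrative recipe: no weight decay for biases, selected by FQN."""
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:                  # skip frozen parameters
            continue
        (no_decay if name.endswith("bias") else decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

model = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.Linear(8, 2))
model[1].requires_grad_(False)                   # mixed requires_grad in one model
# With FSDP, the groups would instead be built after something like:
#   model = FullyShardedDataParallel(model, use_orig_params=True)
optim = torch.optim.AdamW(build_param_groups(model), lr=1e-3)
print([len(g["params"]) for g in optim.param_groups])   # [2, 2]
```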
The lowering of FlatParameter data management from FullyShardedDataParallel to FlatParamHandle, together with the ability to use the original parameters, prepared FSDP for a more composable form. This coincided with a broader eager composable API effort, which imposes a contract requiring each API to preserve the module structure and the original FQNs from named_parameters(). Thus, FSDP adhering to the contract required the aforementioned work. Without the former, managing FlatParameter data would depend on the FullyShardedDataParallel module wrapper, and without the latter, the FQNs would violate the contract.
With the abstractions and semantics aligned, we migrated the pre/post-forward logic to pre/post-forward nn.Module hooks, generalized existing FullyShardedDataParallel code over an _FSDPState object, and landed the resulting fully_shard API. We hope to provide more details on the composable APIs in the future.
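The forward half of this migration can be illustrated with the real nn.Module hook APIs, register_forward_pre_hook and register_forward_hook. LoggingHandle is a hypothetical stand-in that records calls instead of communicating, and register_fsdp_like_hooks is not part of PyTorch. The backward half needs autograd-level hooks, which this sketch omits.

```python
import torch
import torch.nn as nn

class LoggingHandle:
    """Hypothetical stand-in for a FlatParamHandle: records calls instead of
    all-gathering or freeing real data."""
    def __init__(self, name, log):
        self.name, self.log = name, log

    def unshard(self):
        self.log.append(f"unshard {self.name}")

    def reshard(self):
        self.log.append(f"reshard {self.name}")

def register_fsdp_like_hooks(module: nn.Module, handle: LoggingHandle) -> None:
    # Pre-forward hook runs before module.forward; forward hook runs after it.
    # Both lambdas return None, so inputs and outputs are left unchanged.
    module.register_forward_pre_hook(lambda mod, args: handle.unshard())
    module.register_forward_hook(lambda mod, args, out: handle.reshard())

log = []
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
for i in (0, 2):                                 # one "FlatParameter" per Linear
    register_fsdp_like_hooks(model[i], LoggingHandle(f"linear{i}", log))
model(torch.randn(3, 4))
print(log)
# ['unshard linear0', 'reshard linear0', 'unshard linear2', 'reshard linear2']
```

The module structure and FQNs are untouched because nothing is wrapped, which is the property the composable API contract asks for.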
We diagram the design changes above. Originally, FullyShardedDataParallel was 1:1 with FlatParameter, which consists of some number of original parameters. We added FlatParamHandle as a middle layer, as represented by arrow (1). Next, we can imagine replacing FullyShardedDataParallel with another high-level entity and registering hooks on the module instead, as represented by arrow (2). This step only exists logically. Finally, we introduced fully_shard as the high-level entity, which can be 1:k with respect to modules/FlatParameters, as represented by arrow (3).
Here, we focus on two future directions that reflect the need for design flexibility.
Recall that the key question was how to construct FlatParameters for a given model. One view is that there can be two (non-disjoint) approaches to further tackle that question:
- Approach 1: Improve our ability to search the existing set of constructions for performant ones.
- Approach 2: Expand our set of possible constructions to include more performant ones.
Following Approach 1 may include providing more guidance on how to apply FSDP or providing an improved auto wrapping policy. Following Approach 2 is more nuanced. For example, we may parameterize the set of constructions by (1) the number of unshard/reshard pairs per FlatParameter per forward/backward pass and (2) the number of modules involved per unshard/reshard pair. For this note, we focus on two classes in this parameterization:
- Current class: (1 unshard/reshard pair, 1 module per unshard/reshard pair)
- New class: (1 unshard/reshard pair, 2 modules per unshard/reshard pair)
A construction in the new class may be such that one submodule unshards and later a different submodule reshards, hence the 2 modules per unshard/reshard pair. This enables constructions where two sibling modules are grouped together into one FlatParameter without including their parent module. Note that such a construction violates Rule 1 (but not Rule 2), which means that the existing FullyShardedDataParallel wrapper is incompatible.
Last summer, we explored this new class of constructions and found promising throughput and memory gains. However, searching this class and choosing a performant construction pose a challenge. Our algorithm used the first iteration’s execution order to choose the construction, but we have found several blockers to the robustness of this approach.
The above diagram shows the transition from today’s fully_shard, corresponding to (1 unshard/reshard pair, 1 module per unshard/reshard pair), on the left to one corresponding to (1 unshard/reshard pair, 2 modules per unshard/reshard pair) on the right. Because fully_shard can orchestrate the FlatParamHandles, only needing to register hooks on modules, we may replace each singleton module that is 1:1 with a FlatParamHandle by multiple modules that are n:1 with it, where notably, these modules can be siblings.
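Hooks also make the new class easy to express: the unshard can be attached to one sibling’s pre-forward and the reshard to another sibling’s post-forward. The sketch below only logs events, with no real handle or communication. It assumes the first sibling always runs before the last, which the static module tree alone does not guarantee. That is the kind of execution-order information the construction search above depends on.

```python
import torch
import torch.nn as nn

log = []
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
first, last = model[0], model[2]   # siblings sharing one hypothetical handle

# Record when each sibling actually computes.
first.register_forward_hook(lambda mod, args, out: log.append("compute first"))
last.register_forward_hook(lambda mod, args, out: log.append("compute last"))

# One unshard/reshard pair spread over two modules: unshard before the first
# sibling runs, reshard after the last sibling runs. Forward hooks fire in
# registration order, so the reshard is logged after "compute last".
first.register_forward_pre_hook(lambda mod, args: log.append("unshard group"))
last.register_forward_hook(lambda mod, args, out: log.append("reshard group"))

model(torch.randn(2, 4))
print(log)
# ['unshard group', 'compute first', 'compute last', 'reshard group']
```

The parent nn.Sequential has no hook at all, which is exactly the grouping Rule 1 cannot express.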
In fine-tuning, the majority of model parameters are frozen, i.e. do not require gradients, saving gradient and optimizer state memory. For FSDP, the FlatParameter abstraction introduces an issue since it owns the storage of multiple original parameters. For use_orig_params=False, the user cannot even specify different requires_grad across original parameters corresponding to a FlatParameter. For use_orig_params=True, the specification is possible, but as long as one original parameter receives a gradient, the entire FlatParameter gets a gradient, where frozen parameters’ gradients manifest as zeros.
We mention one solution. We may relax Rule 1 and enable (up to) two FlatParameters per module, one for the parameters that require gradients and one for those that do not. This trades off increased communication overhead (up to 2x the number of collectives) for decreased gradient memory usage. Our implementation already accommodates multiple FlatParamHandles per module except for a few places for model/optimizer state checkpointing. So really, the important questions are how users should invoke this code path and whether the complexity it introduces to the system is justifiable.
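The partitioning step of this solution is simple to sketch. split_by_requires_grad is a hypothetical helper, not an FSDP API, and each non-empty group would back its own flat parameter. The numbers compare gradient elements in two cases. In the first, one flat parameter covers everything, so the whole flat parameter gets a gradient. In the second, gradients are kept only for the trainable group, at the cost of up to two all-gathers per unshard.

```python
import torch.nn as nn

def split_by_requires_grad(module: nn.Module):
    """Partition a module's parameter FQNs into (up to) two groups, trainable
    then frozen, each of which could back its own flat parameter."""
    trainable = [n for n, p in module.named_parameters() if p.requires_grad]
    frozen = [n for n, p in module.named_parameters() if not p.requires_grad]
    return [g for g in (trainable, frozen) if g]   # drop an empty group

block = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
block[0].requires_grad_(False)                     # e.g. frozen pretrained weights
groups = split_by_requires_grad(block)
print(groups)   # [['1.weight', '1.bias'], ['0.weight', '0.bias']]

# One flat parameter: the whole thing gets a gradient (frozen parts as zeros).
grad_numel_one_flat = sum(p.numel() for p in block.parameters())
# Split flat parameters: only the trainable group gets a gradient.
grad_numel_split = sum(p.numel() for p in block.parameters() if p.requires_grad)
print(grad_numel_one_flat, grad_numel_split)   # 544 272
```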
The above diagram shows the (up-to) 2 FlatParamHandles per module setup on the right compared to today’s fully_shard on the left, which only permits 1 FlatParamHandle per module.
Rethinking the FSDP design from first principles can help expose new opportunities and prepare us for the future. For example, decoupling the sharding factor from the world size previously gave Hybrid Sharded Data Parallel (HSDP). Looking forward, if NCCL offers efficient all-gatherv/reduce-scatterv, relaxing Constraint 1, then FSDP can adopt a new tensor-shape-preserving sharding algorithm that allows non-pointwise optimizers like Shampoo. Finally, we may leverage compiler techniques to give us a better way to search for performant FlatParameter constructions and/or support a richer class of constructions.
More broadly, eager FSDP serves as a learning ground that prepares us to design for more advanced parallelisms and their compositions. FSDP’s FlatParameter and FlatParamHandle simply represent a grouping of tensors sharing a contiguous storage and its data-managing entity, respectively; we may find these fundamental abstractions to be useful beyond FSDP.
Big thanks to the Fairscale FSDP team for creating such a performant and general implementation of ZeRO-3 and for supporting the upstream to PyTorch!