Given some interest, I am sharing a note (first written internally) on the PyTorch Fully Sharded Data Parallel (FSDP) design. This covers much but not all of it (e.g. it excludes autograd and CUDA caching allocator interaction). I can share more details if there is further interest.
TL;DR
- We rethought the PyTorch FSDP design from first principles to uncover a new one that takes a first step toward improving composability and flexibility.
- This includes an experimental `fully_shard` API that is part of a broader eager distributed composable API effort. This is a work in progress and not ready for general use yet.
- This includes the ability for the existing `FullyShardedDataParallel` to expose the original parameters (not `FlatParameter`s) via `use_orig_params=True`, which enables flexible support for multiple parameter groups.
Introduction
In data parallelism, parameters are replicated across ranks, and each rank operates on a local batch, computing local gradients. Ranks all-reduce their local gradients to compute the global gradient and run a replicated optimizer step.
In sharded data parallelism, parameters are sharded across ranks, and ranks all-gather them as needed for unsharded forward/backward computation and free them when unneeded. Since each rank only manages a shard of the parameters, each rank only needs the corresponding gradient shard to run the optimizer step. This implies that ranks reduce-scatter their local unsharded gradients to compute the sharded global gradient and to run a sharded optimizer step.
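As a minimal sketch of the difference, the snippet below contrasts the two gradient reductions using raw `torch.distributed` collectives. It assumes an already-initialized NCCL process group and a gradient numel divisible by the world size; it is not FSDP code.

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") has already run and that the
# gradient numel divides evenly by the world size (simplifying assumptions).
world_size = dist.get_world_size()
local_grad = torch.randn(16, device="cuda")  # this rank's local unsharded gradient

# Data parallelism: every rank ends up with the full global gradient.
ddp_grad = local_grad.clone()
dist.all_reduce(ddp_grad, op=dist.ReduceOp.AVG)

# Sharded data parallelism: each rank keeps only its gradient shard, which is
# all it needs for the sharded optimizer step.
grad_shard = torch.empty(local_grad.numel() // world_size, device="cuda")
dist.reduce_scatter_tensor(grad_shard, local_grad, op=dist.ReduceOp.AVG)
```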
In sharded data parallelism, a rank's peak memory includes only the sharded parameters, sharded gradients, and sharded optimizer states, plus memory for the largest set of parameters that are simultaneously unsharded. These memory savings may be critical for training large models or scaling batch size.
In fully sharded data parallelism, the sharding factor is specifically set to the world size. I.e., the parameters are sharded across all ranks without replication. This algorithm is commonly called ZeRO-3, and PyTorch’s Fully Sharded Data Parallel (FSDP) is one implementation, where a central challenge is working within the PyTorch framework. (The sharding factor need not be the world size; setting it to be the number of intra-node devices gives the alternative Hybrid Sharded Data Parallel (HSDP).)
PyTorch FSDP was upstreamed from Fairscale FSDP. This note provides one possible perspective on this evolving FSDP design from first principles and motivates a modified design striving for improved composability. Throughout, we treat performance as a first-class constraint since a non-performant design is not usable.
Constraints: Communication
FSDP targets Nvidia GPUs and uses NCCL for collective communications, and FSDP requires all-gather to unshard parameters and reduce-scatter to reduce gradients. PyTorch's `all_gather_into_tensor` requires even input sizes across ranks but is more performant than `all_gather`, and similarly for `reduce_scatter_tensor` vs. `reduce_scatter`. Moreover, for a fixed communication volume, batching data and issuing fewer collectives is more performant. This yields the following two constraints:
Constraint 1: FSDP should communicate even sizes across ranks to use `all_gather_into_tensor` / `reduce_scatter_tensor`.
Constraint 2: FSDP should batch parameters for all-gather and gradients for reduce-scatter.
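As an illustration of these two constraints (and not FSDP's actual communication code), the sketch below issues one batched, evenly sized collective per direction, assuming an already-initialized process group and a recent PyTorch that provides these collectives.

```python
import torch
import torch.distributed as dist

world_size = dist.get_world_size()

# One flat, evenly sized shard per rank (Constraint 1) covering many
# parameters at once (Constraint 2), rather than one collective per parameter.
shard = torch.randn(1024, device="cuda")

unsharded = torch.empty(world_size * shard.numel(), device="cuda")
dist.all_gather_into_tensor(unsharded, shard)        # batched parameter all-gather

local_grads = torch.randn_like(unsharded)            # stand-in for unsharded gradients
grad_shard = torch.empty_like(shard)
dist.reduce_scatter_tensor(grad_shard, local_grads)  # batched gradient reduce-scatter
```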
FlatParameter
Constraints 1 and 2 motivate the `FlatParameter` abstraction: a 1D tensor that is the concatenation of n flattened original parameters (with optional right-padding to ensure evenness). It serves as FSDP's atomic unit of communication, and it owns the storage of the constituent original parameters.
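The construction can be sketched as follows; `build_flat_param` is an illustrative helper, not an FSDP API, showing only the flatten-concatenate-pad steps without the storage ownership and view bookkeeping.

```python
import torch

def build_flat_param(params, world_size):
    # Flatten each original parameter to 1D and concatenate.
    flat = torch.cat([p.detach().reshape(-1) for p in params])
    # Right-pad so the total numel divides evenly by the sharding factor.
    pad = (world_size - flat.numel() % world_size) % world_size
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    return flat  # each rank would own flat.chunk(world_size)[rank]

# Two original parameters with 5 + 6 = 11 elements -> padded to 12 for world_size=4.
flat = build_flat_param([torch.randn(5), torch.randn(2, 3)], world_size=4)
assert flat.numel() == 12
```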
When computation involves an original parameter, the owning `FlatParameter` must be unsharded and can only be resharded outside the computation. Thus, a performant `FlatParameter` construction groups original parameters involved in computation around the same time, and ideally each `FlatParameter` is only unsharded/resharded once for a given forward or backward to minimize the number of all-gathers.
Key Question: For a given model, how should FSDP construct `FlatParameter`s from the model's parameters?
FlatParameter Construction
`FullyShardedDataParallel` is a module wrapper, applying transformations to the wrapped module at construction time and at runtime (i.e. forward/backward/optimizer step). As an eager API, it only has access to the static module structure at construction time. FSDP wrapping leverages the module structure to inform the `FlatParameter` construction, in the hope that model authors group parameters with the desired locality into the same module or module subtrees.
Rule 1: If the user wraps `fsdp_module = FullyShardedDataParallel(module)`, then every parameter in `module` not already flattened is flattened into a single new `FlatParameter` and assigned to `fsdp_module`.
Rule 1 is just one way to leverage module structure, which happens to be sufficiently performant while being simple to reason about. From a model printout, a user/developer can quickly infer the parameter assignment. This simplicity can be crucial to debugging and adoption. An example module tree is shown above, where red indicates a `FullyShardedDataParallel`-wrapped module and yellow indicates a non-directly-wrapped module. This wrapping constructs four `FlatParameter`s (shown by the dotted lines), which are assigned to modules 0, 1, 3, and 7, respectively.
To be clear, if a user only applies `FullyShardedDataParallel` to the root module, then there is only a single `FlatParameter` consisting of all original parameters, meaning there is no communication/computation overlap and the entire unsharded parameter size contributes to peak memory. Thus, we need multiple `FlatParameter`s, which, by Rule 1, requires recursive/nested FSDP wrapping.
Manual & Auto Wrapping
Manual wrapping refers to applying `submodule = FullyShardedDataParallel(submodule)` on target submodules, where this application proceeds bottom-up. For large/complex models, manual wrapping may not be tractable, so we introduced auto wrapping, which uses heuristics to perform the nested wrapping automatically. Heuristics include a parameter size threshold and target `nn.Module` classes to wrap.
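A sketch of both styles follows (assuming an initialized process group and available GPUs; the model and size threshold are illustrative).

```python
import functools
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024), nn.Linear(1024, 8))

# Manual wrapping: wrap target submodules bottom-up, then the root; each wrap
# flattens the not-yet-flattened parameters into a new FlatParameter.
model[0] = FSDP(model[0])
model[1] = FSDP(model[1])
model = FSDP(model)  # the remaining parameters (model[2]) go to the root FlatParameter

# Auto wrapping: a heuristic policy decides which submodules to wrap.
auto_model = FSDP(
    nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024), nn.Linear(1024, 8)),
    auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e6)),
)
```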
Historical: The original design employed a double module wrapping with a second module wrapper: `FullyShardedDataParallel(FlattenParamsWrapper(module))`. `FullyShardedDataParallel` was responsible for all-gathering, and `FlattenParamsWrapper` was responsible for setting the original parameters to be views into the all-gathered `FlatParameter`. We consider this all-gather plus view-setting together to be the (logical) unshard.
Constraints: Unsharding Parameters for Computation
Constraint 3: For correctness, a module's parameters must be unsharded during its forward/backward computation and only resharded outside that. For memory performance, the unsharded lifetime should be minimized, and for throughput performance, the number of unshard/reshards should be minimized.
Pre/Post-Forward/Backward
To address the throughput part of Constraint 3:
Rule 2: For a given `FlatParameter` and forward/backward pass, FSDP only unshards and reshards the `FlatParameter` once.
This ensures the minimal number of all-gathers for a fixed number of `FlatParameter`s without changing the algorithm, namely 2x the number of `FlatParameter`s per iteration (one all-gather in forward and one in backward).
Given Rules 1 and 2 and Constraint 3, the choice of when to unshard/reshard becomes fixed. For the module that owns a `FlatParameter` (unique by Rule 1), FSDP should unshard right before the module's forward and reshard right after. Similarly, FSDP should unshard right before gradient computation for any tensor in the module's forward output and reshard after the `FlatParameter`'s gradient computation. This defines four points per `FlatParameter`: pre-forward unshard, post-forward reshard, pre-backward unshard, and post-backward reshard (in accordance with Rule 2).
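The sketch below marks these four points with module and tensor hooks. It is a conceptual stand-in, not FSDP's implementation: `_unshard`/`_reshard` are placeholder functions, and the post-backward point is approximated with a full backward hook, whereas FSDP actually hooks gradient accumulation.

```python
import torch
import torch.nn as nn

def _unshard(name): print(f"[{name}] unshard")
def _reshard(name): print(f"[{name}] reshard")

def attach_points(module: nn.Module, name: str) -> None:
    # (1) pre-forward unshard
    module.register_forward_pre_hook(lambda mod, args: _unshard(name))

    def post_forward(mod, args, output):
        _reshard(name)  # (2) post-forward reshard
        if isinstance(output, torch.Tensor) and output.requires_grad:
            # (3) pre-backward unshard: fires when the gradient w.r.t. the
            # module's forward output arrives during backward
            def pre_backward(grad):
                _unshard(name)
                return grad
            output.register_hook(pre_backward)
        return output

    module.register_forward_hook(post_forward)

    # (4) post-backward reshard (approximation; FSDP hooks grad accumulation)
    module.register_full_backward_hook(lambda mod, gin, gout: _reshard(name))

lin = nn.Linear(4, 4)
attach_points(lin, "lin")
lin(torch.randn(2, 4)).sum().backward()
```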
FSDP also has some additional logic to run before the entire forward and some to run after the entire backward. These correspond to the root pre-forward and post-backward final callback.
New Abstraction: FlatParamHandle
Previously, we saw how an FSDP wrapping informs a `FlatParameter` construction. Under the existing design, each `FullyShardedDataParallel` instance manages the data for its one `FlatParameter` directly (e.g. the pre/post-forward/backward unshard/reshards). However, applying a `FullyShardedDataParallel` module wrapper is not actually necessary! I.e., we can achieve the same `FlatParameter` construction without imposing on the module structure itself.
Moreover, Rule 1 tightly couples one `FullyShardedDataParallel` instance with one `FlatParameter`. While this simplifies the design, it is not the best abstraction and can obfuscate that we are intentionally following this simplifying rule. In reality, once we have fixed a `FlatParameter` construction, we can have one entity orchestrate the `FlatParameter`s (e.g. their unshard/reshards), and we can lower the `FlatParameter` data management to another entity that is strictly per-`FlatParameter`. For example, the high-level entity can be `FullyShardedDataParallel`, but the lower-level entity need not be.
From this insight, we introduced another class, `FlatParamHandle`, that performs exactly the `FlatParameter` data management and is strictly 1:1 with it. `FullyShardedDataParallel` only needs to interface with `FlatParamHandle`s. Last year, we refactored the FSDP internals to achieve this separation, which resulted in `FlatParameter` being nothing more than a plain `nn.Parameter`.
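A heavily simplified sketch of this division of labor is shown below. `ToyFlatParamHandle` is illustrative only; the real `FlatParamHandle` also manages padding, original-parameter views, gradient hooks, mixed precision, and more.

```python
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn

class ToyFlatParamHandle:
    """Per-FlatParameter data manager: owns exactly one flat parameter shard."""

    def __init__(self, flat_param: nn.Parameter, process_group=None):
        self.flat_param = flat_param  # holds only this rank's (padded) shard
        self.process_group = process_group
        self.world_size = dist.get_world_size(process_group)
        self._unsharded: Optional[torch.Tensor] = None

    def unshard(self) -> torch.Tensor:
        # All-gather the shards into one unsharded flat tensor; real FSDP then
        # sets the original parameters as views into this tensor.
        self._unsharded = torch.empty(
            self.world_size * self.flat_param.numel(),
            dtype=self.flat_param.dtype,
            device=self.flat_param.device,
        )
        dist.all_gather_into_tensor(
            self._unsharded, self.flat_param.detach(), group=self.process_group
        )
        return self._unsharded

    def reshard(self) -> None:
        # Drop the unsharded storage; only the shard stays resident.
        self._unsharded = None
```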
New Feature: use_orig_params=True
Part of lowering the `FlatParameter` data management to `FlatParamHandle` included enforcing a single code path each for unsharding and resharding, regardless of the calling context (forward/backward, model checkpointing, etc.). This discipline allowed us to augment the unshard and reshard logic in a mostly "just-works" way.
The optimizer step conventionally operates on the registered parameters (returned by `nn.Module.parameters()`). For the existing design, the `FlatParameter`s are registered, while the original parameters are de-registered and replaced by plain `Tensor`s. Hence, the optimizer step runs on the sharded `FlatParameter`s, and the original parameters are lost.
However, when the parameters are sharded, we can still define semantics that follow the single-program multiple-device (SPMD) paradigm expected for data parallelism. In particular, accessing an original parameter on a rank can return the part present in the rank's `FlatParameter` shard or an empty tensor if none is present, and similarly, the parameter can receive a corresponding gradient if present in the rank's `FlatParameter` shard or `None` if not. One caveat is that, due to FSDP's sharding algorithm, we can only return the flattened sharded original parameter, losing the tensor structure.
These semantics for using the original parameters are available today by passing `use_orig_params=True` to the FSDP constructor, and they were added exactly by augmenting the existing unshard/reshard logic. In that case, `named_parameters()` returns the original fully-qualified names (FQNs), not ones like `<prefix>.flat_param`. This enables using multiple optimizer parameter groups and/or different `requires_grad` within one `FlatParameter`'s original parameters, and it helps hide the `FlatParameter` abstraction from users. We hope to converge to setting `use_orig_params=True` by default in the future.
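For example, a sketch of multiple optimizer parameter groups under `use_orig_params=True` (assuming an initialized process group; the model and hyperparameters are illustrative):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(nn.Transformer(d_model=256, nhead=4), use_orig_params=True)

# named_parameters() yields the original FQNs, so we can partition by name,
# e.g. to disable weight decay on biases.
decay, no_decay = [], []
for fqn, param in model.named_parameters():
    (no_decay if fqn.endswith("bias") else decay).append(param)

optim = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.01},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=1e-3,
)
```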
New API: From Wrapper to No Wrapper – fully_shard
The lowering of `FlatParameter` data management from `FullyShardedDataParallel` to `FlatParamHandle` and the ability to use the original parameters prepared FSDP for a more composable form. This coincided with a broader eager composable API effort, which imposes a contract requiring each API to preserve the module structure and the original FQNs from `named_parameters()`. Thus, FSDP adhering to the contract required the aforementioned work: without it, the `FullyShardedDataParallel` module wrapper would be needed to manage `FlatParameter` data, and the FQNs would fail the contract, respectively.
With the abstractions and semantics aligned, we migrated the pre/post-forward logic to pre/post-forward `nn.Module` hooks, generalized the existing `FullyShardedDataParallel` code over an `_FSDPState` object, and landed the resulting `fully_shard` API. We hope to provide more details on the composable APIs in the future.
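A usage sketch, assuming the experimental import path at the time of writing and an initialized default process group (the module is illustrative):

```python
import torch.nn as nn
from torch.distributed._composable import fully_shard  # experimental at the time of writing

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024), nn.Linear(1024, 8))

# Applied bottom-up like manual wrapping, but in place: no wrapper module is
# inserted, so the module classes and original FQNs are preserved.
fully_shard(model[0])
fully_shard(model[1])
fully_shard(model)  # remaining parameters (model[2]) form the root group

assert isinstance(model, nn.Sequential)                 # module structure preserved
fqns = [name for name, _ in model.named_parameters()]   # e.g. "0.weight", no "flat_param"
```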
We diagram the design changes above. Originally, `FullyShardedDataParallel` was 1:1 with `FlatParameter`, which consists of some number of original parameters. We added `FlatParamHandle` as a middle layer, as represented by arrow (1). Next, we can imagine replacing `FullyShardedDataParallel` with another high-level entity and registering hooks on the module instead, as represented by arrow (2). This step only exists logically. Finally, we introduced `fully_shard` as the high-level entity, which can be 1:k with respect to modules/`FlatParamHandle`s/`FlatParameter`s, as represented by arrow (3).
Future Directions
Here, we focus on two future directions that reflect the need for design flexibility.
Revisiting FlatParameter Construction
Recall that the key question was how to construct `FlatParameter`s for a given model. One view is that there can be two (non-disjoint) approaches to further tackle that question:
1. Improve our ability to search the existing set of constructions for performant ones.
2. Expand our set of possible constructions to include more performant ones.
Following Approach 1 may include providing more guidance on how to apply FSDP or providing an improved auto wrapping policy. Following Approach 2 is more nuanced. For example, we may parameterize the set of constructions by (1) the number of unshard/reshard pairs per `FlatParameter` per forward/backward pass and (2) the number of modules involved per unshard/reshard pair. For this note, we focus on two classes in this parameterization:
- Current class: (1 unshard/reshard pair, 1 module per unshard/reshard pair)
- New class: (1 unshard/reshard pair, 2 modules per unshard/reshard pair)
A construction in the new class may be such that one submodule unshards and, later, a different submodule reshards, hence the 2 modules per unshard/reshard pair. This enables constructions where two sibling modules are grouped together into one `FlatParameter` without including their parent module. Note that such a construction violates Rule 1 (but not Rule 2), which means that the existing `FullyShardedDataParallel` wrapper is incompatible.
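A conceptual sketch of such a construction (not a supported API): two sibling submodules share one parameter group, unsharded at the first sibling's pre-forward and resharded only at the second sibling's post-forward, with `_unshard`/`_reshard` as placeholders.

```python
import torch
import torch.nn as nn

def _unshard(): print("unshard shared sibling group")
def _reshard(): print("reshard shared sibling group")

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(64, 64)  # sibling 1
        self.mlp = nn.Linear(64, 64)   # sibling 2

    def forward(self, x):
        return self.mlp(self.attn(x))

block = Block()
# 2 modules per unshard/reshard pair: unshard at attn's pre-forward...
block.attn.register_forward_pre_hook(lambda mod, args: _unshard())
# ...and reshard only after mlp's forward, without touching the parent Block.
block.mlp.register_forward_hook(lambda mod, args, out: _reshard())

block(torch.randn(2, 64))
```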
Last summer, we explored this new class of constructions and found promising throughput and memory gains. However, searching this class and choosing a performant construction pose a challenge. The algorithm we used chose the construction based on the first iteration's execution order, but we have found several blockers to the robustness of this approach.
The above diagram shows the transition from today's `fully_shard`, corresponding to (1 unshard/reshard pair, 1 module per unshard/reshard pair) on the left, to one corresponding to (1 unshard/reshard pair, 2 modules per unshard/reshard pair) on the right. Because `fully_shard` can orchestrate the `FlatParamHandle`s, only needing to register hooks on modules, we may replace each singleton module that is 1:1 with a `FlatParamHandle` with multiple modules that are n:1 with that `FlatParamHandle`, where notably, these modules can be siblings.
Memory Efficient Fine-Tuning
In fine-tuning, the majority of model parameters are frozen, i.e. do not require gradients, saving gradient and optimizer state memory. For FSDP, the `FlatParameter` construct introduces an issue since it owns the storage of multiple original parameters. For `use_orig_params=False`, the user cannot even specify different `requires_grad` across the original parameters corresponding to a `FlatParameter`. For `use_orig_params=True`, the specification is possible, but as long as one original parameter receives a gradient, the entire `FlatParameter` gets a gradient, where the frozen parameters' gradients manifest as zeros.
We mention one solution. We may relax Rule 1 and enable (up to) two `FlatParameter`s per module: one for the parameters that require gradients and one for those that do not. This trades increased communication overhead (up to 2x the number of collectives) for decreased gradient memory usage. Our implementation already accommodates multiple `FlatParamHandle`s per module except for a few places for model/optimizer state checkpointing, so really, the important questions are how users should invoke this code path and whether the complexity it introduces to the system is justifiable.
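A sketch of the partitioning that such a relaxed rule would rely on; `split_by_requires_grad` is an illustrative helper, not an existing FSDP code path.

```python
import torch.nn as nn

def split_by_requires_grad(module: nn.Module):
    # Partition the module's own parameters into (up to) two groups, each of
    # which could back its own FlatParameter/FlatParamHandle.
    trainable, frozen = [], []
    for param in module.parameters(recurse=False):
        (trainable if param.requires_grad else frozen).append(param)
    return trainable, frozen

layer = nn.Linear(16, 16)
layer.weight.requires_grad_(False)  # e.g. a frozen weight during fine-tuning
trainable, frozen = split_by_requires_grad(layer)
# Only the trainable group's FlatParameter would ever allocate gradient memory.
```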
The above diagram shows the (up to) 2 `FlatParamHandle`s-per-module setup on the right compared to today's `fully_shard` on the left, which only permits 1 `FlatParamHandle` per module.
Final Thoughts
Rethinking the FSDP design from first principles can help expose new opportunities and prepare us for the future. For example, decoupling the sharding factor from the world size previously gave Hybrid Sharded Data Parallel (HSDP). Looking forward, if NCCL offers efficient all-gatherv/reduce-scatterv, relaxing Constraint 1, then FSDP can adopt a new tensor-shape-preserving sharding algorithm that allows non-pointwise optimizers like Shampoo. Finally, we may leverage compiler techniques to give us a better way to search for performant `FlatParameter` constructions and/or support a richer class of constructions.
More broadly, eager FSDP serves as a learning ground that prepares us to design for more advanced parallelisms and their compositions. FSDP's `FlatParameter` and `FlatParamHandle` simply represent a grouping of tensors sharing a contiguous storage and its data-managing entity, respectively; we may find these fundamental abstractions useful beyond FSDP.
Big thanks to the Fairscale FSDP team for creating such a performant and general implementation of ZeRO-3 and for supporting the upstream to PyTorch!