A common source of confusion for users encountering the two efforts is understanding how they relate to each other and deciding which one to use.
Introduction
Recent efforts to train extremely large models ran into the issue that their checkpoints are too big to fit in the main memory of a single host, forcing them to employ a multi-host checkpointing strategy.
Saving such large models turned out to be so complex a task that the desire for common infrastructure emerged, in the form of two similar efforts with complementary features.
Common challenges are how to scale up IO to exploit the capabilities of the whole cluster and how to deal with changes in cluster topology between save and load, commonly referred to as resharding.
PyTorch Distributed Checkpointing
Part of PyTorch Distributed, this is a low-level API designed to be the infrastructure for users who want to introduce distributed checkpointing in their frameworks.
The API is divided into two layers, one for checkpoint planning and another for storage IO. On top of that, there is an implementation of load and save that coordinates the two, dealing with common issues of distributed programming such as synchronization and error propagation.
The checkpoint planning layer translates between the user state_dict and a low-level representation of IO reads and writes, while dealing with the issues of resharding and offering extensibility to handle new distributed data schemes. The default implementation of this layer supports ShardedTensor; handling other PyTorch distributed mechanisms like FSDP and DistributedTensor is on our H2 roadmap.
The storage IO layer deals with reading and writing to the underlying storage, plus converting the storage-agnostic checkpoint metadata format between the underlying system and PTD checkpoint. The default implementation is a simple filesystem-based solution aimed at enabling users to try the API; it doesn’t feature any of the advanced performance and memory-saving techniques that TorchSnapshot offers. Our H2 goal is to ensure a consistent baseline performance.
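As a rough sketch of how the two layers fit together, the snippet below saves and reloads a state_dict through the default planning layer and the filesystem-based storage layer. It assumes the torch.distributed.checkpoint module with save_state_dict/load_state_dict and the FileSystemWriter/FileSystemReader classes; exact entry points may differ across PyTorch versions, and the checkpoint path is a placeholder.

```python
import torch
import torch.distributed.checkpoint as dist_cp

# A stand-in for a (potentially sharded) model; in a real job this would
# run under torchrun with a process group initialized.
model = torch.nn.Linear(128, 64)
CHECKPOINT_DIR = "/tmp/ptd_checkpoint"  # hypothetical path

# Save: the default planner turns the state_dict into low-level write items,
# and FileSystemWriter performs the IO and persists the checkpoint metadata.
dist_cp.save_state_dict(
    state_dict=model.state_dict(),
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
)

# Load: values are restored in place into the destination state_dict,
# resharding as needed if the cluster topology changed since saving.
state_dict = model.state_dict()
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
)
model.load_state_dict(state_dict)
```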
Its focus is to provide a foundation for checkpoint metadata, handling of core PTD building blocks like ShardedTensor, and a robust baseline implementation. Our H2 focus is to expand its coverage of PTD building blocks to include FSDP, DistributedTensor, and 2D parallelism combining the two.
TorchSnapshot
TorchSnapshot provides PyTorch users with easy access to performant and reliable checkpointing functionalities with minimal change to their training scripts. This library is independent of any higher-level training stacks or frameworks. It’s designed to work for a wide variety of training use cases, such as distributed training, in which checkpointing is commonly seen as a huge pain point by users.
Users may consider TorchSnapshot if the following benefits their use case (a short usage sketch follows the list):
- Performance - TorchSnapshot offers a highly optimized checkpointing implementation for common training setups. It outperforms torch.save to varying degrees depending on the use case.
- Memory usage - TorchSnapshot is adaptive to the amount of host memory available during the checkpoint saving/loading process. When memory is constrained, it automatically reduces memory consumption, which leads to fewer host OOMs during checkpointing. When memory is abundant, it automatically utilizes more memory to achieve higher throughput.
- Manipulability - TorchSnapshot provides intuitive, efficient, and flexible access to checkpoint content, which allows for processing (e.g., re-sharding, quantizing) large model checkpoints that do not fit in RAM.
- Cloud storage - TorchSnapshot offers performant and reliable integration with commonly used cloud object storage.
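As an illustration of the minimal-change integration described above, the sketch below takes a snapshot of the application state and later restores it. It assumes the torchsnapshot package with Snapshot.take and Snapshot.restore; the path and state names are placeholders.

```python
import torch
import torchsnapshot

# Application state: a dict of stateful objects, i.e. anything exposing
# state_dict()/load_state_dict(), such as modules and optimizers.
model = torch.nn.Linear(128, 64)
optim = torch.optim.Adagrad(model.parameters())
app_state = {"model": model, "optim": optim}

# Take a snapshot; in a distributed job all ranks participate, so the
# aggregate IO bandwidth of the cluster is used.
snapshot = torchsnapshot.Snapshot.take(path="/tmp/my_snapshot", app_state=app_state)

# Later, possibly in a job with a different topology, restore in place.
snapshot = torchsnapshot.Snapshot(path="/tmp/my_snapshot")
snapshot.restore(app_state=app_state)
```

The same pattern should also apply when the path points at supported cloud object storage rather than a local filesystem.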
Usage Recommendation
Use Cases for PTD Checkpoint
Being a low-level building block, it is the best fit for users who need to roll out their own distributed checkpoint solution, in particular when they have complex data distributions not easily expressible with current primitives such as ShardedTensor.
Use Cases for TorchSnapshot
Users expecting a complete snapshot management solution should use TorchSnapshot.
Guidelines
Our recommended guideline is that most users should consider TorchSnapshot first, as it offers a more user-friendly and complete solution.
PTD checkpoint should be used by those who need the low-level control offered by its API or bleeding-edge support for some PTD features.
Collaboration Between the Two Efforts
We’re actively looking for areas to increase code sharing. Our current focus is dealing with the complexities of load-time resharding and establishing a common interface through which distributed models can present their state to checkpointing solutions. Another collaboration stream is around upcoming support for the new DistributedTensor and its use in tandem with FSDP.