State of model creation/initialization/serialization in PyTorch Core

Cross-posting a note written internally that summarizes the current state of model creation/initialization/loading in PyTorch Core and some intended next steps.

Motivation

This post aims to summarize the state of model creation/initialization/loading/serialization in PyTorch core and the asks/proposals we have received in this area. We lay out next steps in the conclusion; there will be a follow-up to discuss those.

We break this discussion down into the following categories:

  • Skipping initialization of module parameters
  • Loading of state dict from file
  • Loading of parameters from state_dict into model
  • Deferred Initialization

The user-facing API that we discuss here is the following:

m = SomeModule(...)
if use_state_dict:
    state_dict = torch.load("....pt")
    m.load_state_dict(state_dict)

We attempt to elucidate the proposals we have received with respect to the above snippet.

Skipping initialization of module parameters

Let us first consider the instantiation of the module:

m = SomeModule(...)

This line allocates storage for each parameter/buffer in the module and randomly initializes the parameters/buffers per the default initialization scheme for the module.

Problems that motivate skipping initialization

In #29523, there was extensive discussion about skipping random weight initialization, because the behavior described above is wasteful in the following cases:

  1. User creates model on CPU and then moves it to device
m = SomeModule(...).cuda()
  • Initialization is slower on CPU than on GPU
  • All randomly initialized parameters/buffers have to be copied from CPU memory to device memory
  2. User intends to load params/buffers from the state dict
# 1 copy of randomly initialized parameters in memory
m = SomeModule(...)
# 1 copy of parameters loaded into memory
state_dict = torch.load("foo.pt")
# copy parameters from state dict into parameters in module
m.load_state_dict(state_dict)
  • Wasteful random initialization
  • 2x the parameters in memory (1x from loading the state dict and 1x from module init), which increases peak memory
  3. User wants to use an initialization scheme other than the default
m = SomeModule(...)
def weights_init_fn(module):
    # some custom initialization scheme, e.g. for Linear layers
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight)
m.apply(weights_init_fn)

What has been done for skip initialization

To address the above problems, we

  1. Used the meta device and added a to_empty method to move modules off the meta device
  2. Added device/dtype kwargs to all nn.Modules in core

and used these to add a torch.nn.utils.skip_init API, which does the following:

m = SomeModule(..., device='meta')
m.to_empty(device=final_device)
# can do custom init or init on device after

Afterwards, initialization can be done via torch.nn.init.* or m.reset_parameters() (if the module has a reset_parameters method). This skip_init API is guaranteed to work provided that:

  1. Module accepts device kwarg in constructor that is passed recursively to all params/buffers
  2. Module does not perform any computation on parameters other than initialization in its constructor
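As a concrete sketch, assuming a PyTorch release that ships torch.nn.utils.skip_init, here is the API applied to nn.Linear (which satisfies both conditions above), followed by the equivalent manual meta-device recipe. The layer sizes are arbitrary choices for illustration.

```python
import torch
from torch import nn
from torch.nn.utils import skip_init

# skip_init builds the module on the meta device, then calls
# to_empty(device=...): storage is allocated but left uninitialized.
m = skip_init(nn.Linear, 64, 32, device="cpu")

# Initialize explicitly afterwards, e.g. with the module's default scheme...
m.reset_parameters()
# ...or with any torch.nn.init function (these run under no_grad internally).
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)

# The equivalent manual recipe: construct on meta, materialize, initialize.
m2 = nn.Linear(64, 32, device="meta")
m2.to_empty(device="cpu")
m2.reset_parameters()  # default Linear init: U(-1/sqrt(fan_in), 1/sqrt(fan_in))
```

Either way, only one random initialization runs, and it runs on the final device.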

More recently, a torch.device context manager was introduced, so we can now do:

with torch.device('meta'):
    m = SomeModule(...)
...

This eliminates the need to recursively propagate device kwargs. However, one caveat is that any tensor constructor with its device explicitly set will not be overridden by this context manager. For example, in the following snippet, t will be allocated on CPU rather than on the meta device.

with torch.device('meta'):
    t = torch.randn(2, 5, device='cpu')

Problems with our solution for skip initialization

While this solution does address problems (1) and (3) listed above, it does not address the issue of 2x model parameters being in memory. We will revisit this problem under “Loading of parameters from state_dict into model”.

Per #90645, this approach also does not work well with nested FSDP. In the nested case, FSDP runs initialization bottom-up, which means that child submodules materialize (allocate storage for and initialize) their meta-device parameters before their parents. As a result, when a parent module later materializes its own meta-device parameters by calling module.to_empty, its children's parameters/buffers have already been initialized. Since module.to_empty recurses into all child submodules, those already-initialized parameters are replaced with uninitialized memory and have to be reinitialized.
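The sketch below reproduces the problem on CPU without FSDP: the child is materialized first, and a recursive to_empty on the parent swaps out the child's initialized storage. It then shows one possible workaround, a hypothetical helper materialize_meta_only (not a PyTorch API) that only allocates tensors still on the meta device. The Parent/child/head module names are also made up for illustration.

```python
import torch
from torch import nn

class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = nn.Linear(8, 8)
        self.head = nn.Linear(8, 2)

def build() -> Parent:
    with torch.device("meta"):
        model = Parent()
    # Bottom-up: the child is materialized and initialized before the parent.
    model.child.to_empty(device="cpu")
    model.child.reset_parameters()
    return model

# Problem: to_empty on the parent recurses into the already-initialized child.
naive = build()
before = naive.child.weight.data_ptr()
naive.to_empty(device="cpu")
clobbered = naive.child.weight.data_ptr() != before  # child now holds fresh, uninitialized storage

def materialize_meta_only(module: nn.Module, device) -> None:
    """Hypothetical helper: allocate storage only for tensors still on meta.
    Caveat: tied (shared) meta parameters would be untied by this sketch."""
    for sub in module.modules():
        for name, p in list(sub.named_parameters(recurse=False)):
            if p.is_meta:
                new = nn.Parameter(torch.empty_like(p, device=device),
                                   requires_grad=p.requires_grad)
                setattr(sub, name, new)
        for name, b in list(sub.named_buffers(recurse=False)):
            if b.is_meta:
                setattr(sub, name, torch.empty_like(b, device=device))

fixed = build()
child_ptr = fixed.child.weight.data_ptr()
materialize_meta_only(fixed, "cpu")
fixed.head.reset_parameters()  # initialize only what was just materialized
```

The design choice here is to skip anything that is no longer on meta rather than to track which modules were initialized; that keeps the helper stateless but relies on "is_meta" being an accurate signal of "not yet materialized".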

Loading of parameters from the state_dict into the model

m.load_state_dict(state_dict)

Recall from earlier that load_state_dict copies the data of each tensor in the loaded state_dict in place into the storage already allocated for the corresponding parameter/buffer. This means that there will be one set of parameters from torch.load and one set of parameters in the model instance in memory, even when we use the meta device and to_empty trick mentioned in the first section. Another suggestion, in #64601, has been to implement a flag that reuses the tensors in the state_dict as the model parameters:

m = SomeModule(device='meta')
state_dict = torch.load('foo.pt')
m.load_state_dict(state_dict, reuse_tensors=True)

HuggingFace has custom code that does this here.
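To make the "two copies" point concrete, the sketch below first shows that the default load_state_dict leaves the module and the state_dict with separate storages, then gives a minimal hypothetical helper, load_reusing_tensors, that installs the loaded tensors as the module's parameters/buffers instead of copying. This is not the proposed reuse_tensors flag or HuggingFace's code; it skips shape/dtype/strictness checks and touches the private _parameters dict.

```python
import io
import torch
from torch import nn

# Save a small checkpoint to an in-memory buffer (stands in for "foo.pt").
buf = io.BytesIO()
torch.save(nn.Linear(4, 3).state_dict(), buf)
buf.seek(0)
state_dict = torch.load(buf)

# Default behavior: values are copied into the module's own storage,
# so two copies of every tensor are alive at this point.
m = nn.Linear(4, 3)
m.load_state_dict(state_dict)
copied = m.weight.data_ptr() != state_dict["weight"].data_ptr()

def load_reusing_tensors(module: nn.Module, state_dict) -> None:
    """Hypothetical helper: alias state_dict tensors instead of copying them."""
    for key, tensor in state_dict.items():
        owner_path, _, name = key.rpartition(".")
        owner = module.get_submodule(owner_path)  # "" returns module itself
        if name in owner._parameters:
            requires_grad = owner._parameters[name].requires_grad
            setattr(owner, name, nn.Parameter(tensor, requires_grad=requires_grad))
        else:  # assume it is a buffer
            setattr(owner, name, tensor)

with torch.device("meta"):
    m2 = nn.Linear(4, 3)  # nothing allocated for the discarded random init
load_reusing_tensors(m2, state_dict)
shared = m2.weight.data_ptr() == state_dict["weight"].data_ptr()
```

Combined with meta-device construction, this keeps a single copy of each tensor in memory: the one produced by torch.load.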

Loading of the state dict

state_dict = torch.load("foo.pt")

The main issue here is that the API of torch.load is inflexible and requires loading a copy of the entire state dict into memory. This can be problematic when the state dict is large.

Proposals with respect to torch.load

Selective loading (and saving) of state dicts #97196 / #75242 / #64327

Rather than loading the full state dict into memory at once, load one submodule, tensor (or even tensor slice) at a time.

  1. There was some discussion about the best file format for checkpoints in order to enable this (e.g. dbm).
  2. Some use cases might include:
  • Converting multi-gpu partitioned param checkpoints into normal checkpoints (this seems similar to the FSDP state dict conversion issue that was mitigated by using ShardedStateDicts)
  • Converting checkpoints from one framework to another or one dtype to another
  • Post training quantization
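As a toy illustration of selective loading, the sketch below assumes a hypothetical checkpoint layout with one torch.save file per tensor (not a format discussed in the linked issues, which considered options such as dbm). The helpers save_per_tensor and load_selected are made up for this example; only the tensors of the requested submodule are ever read from disk.

```python
import os
import tempfile
import torch
from torch import nn

def save_per_tensor(state_dict, directory):
    """Hypothetical layout: one file per tensor, named after its key."""
    os.makedirs(directory, exist_ok=True)
    for key, tensor in state_dict.items():
        torch.save(tensor, os.path.join(directory, f"{key}.pt"))

def load_selected(directory, prefix):
    """Read only the files whose key starts with `prefix`."""
    selected = {}
    for filename in sorted(os.listdir(directory)):
        key = filename[: -len(".pt")]
        if key.startswith(prefix):
            selected[key] = torch.load(os.path.join(directory, filename))
    return selected

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
with tempfile.TemporaryDirectory() as ckpt_dir:
    save_per_tensor(model.state_dict(), ckpt_dir)
    part = load_selected(ckpt_dir, "2.")  # only submodule "2" is loaded

# Load the selected tensors into a standalone copy of that submodule.
fresh = nn.Linear(8, 2)
fresh.load_state_dict({k[len("2."):]: v for k, v in part.items()})
```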

Lazy torch.load #79967

m = SomeModule(device='meta')
state_dict = torch.load('foo.pt', lazy=True)
m.load_state_dict(state_dict)
  • For clarification, when we say lazy here we aren’t referring to the PyTorch Lazy Tensor project, but rather to the idea of tensors that are only loaded into memory on demand
  • llama.cpp, which attempts to do LLaMA inference in pure C++, has a workaround for this that bypasses the dependency on torch.load here
  • There is also potential to let the state dict be memory-mapped
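The sketch below combines the two ideas above (lazy access and memory mapping) under stated assumptions: each tensor is stored as raw bytes in its own file alongside dtype/shape metadata, and a hypothetical LazyStateDict maps a file only when its key is first accessed, using numpy.memmap in copy-on-write mode plus torch.from_numpy. This is not the proposed torch.load(..., lazy=True) API; it only shows the access pattern.

```python
import os
import tempfile
import numpy as np
import torch

def save_raw(tensor: torch.Tensor, path: str) -> dict:
    """Write raw bytes and return the metadata needed to map them back."""
    array = tensor.detach().cpu().contiguous().numpy()
    array.tofile(path)
    return {"path": path, "dtype": array.dtype.str, "shape": array.shape}

class LazyStateDict:
    """Hypothetical: maps keys to memory-mapped tensors, created on first access."""

    def __init__(self, metadata: dict):
        self._metadata = metadata
        self._cache = {}

    def keys(self):
        return self._metadata.keys()

    def __getitem__(self, key):
        if key not in self._cache:
            meta = self._metadata[key]
            # mode="c": copy-on-write mapping; pages are read from disk on demand
            array = np.memmap(meta["path"], dtype=meta["dtype"], mode="c",
                              shape=meta["shape"])
            self._cache[key] = torch.from_numpy(array)
        return self._cache[key]

weights = {"linear.weight": torch.randn(3, 4), "linear.bias": torch.randn(3)}
with tempfile.TemporaryDirectory() as d:
    meta = {k: save_raw(v, os.path.join(d, k + ".bin")) for k, v in weights.items()}
    lazy = LazyStateDict(meta)
    w = lazy["linear.weight"]  # only this file gets mapped
    matches = torch.equal(w, weights["linear.weight"])
    mapped_keys = list(lazy._cache)
```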

Deferred Initialization

There are also requests to upstream the deferred initialization API from torchdistX into core. This makes all tensors within a given module fake and records operations performed on them. When tensors/modules are materialized, it then replays these operations.

from torchdistx.deferred_init import deferred_init, materialize_module

# parameters are all fake / meta
m = deferred_init(SomeModule, ...)
# some code
...
materialize_module(m)
# all real with code replayed
print(list(m.parameters()))

The meta device skip-initialization discussed above does not have this recording behavior.

torchdistX uses its own C++ implementation of fake tensors, which is distinct from the PT2 FakeTensor. In a related vein, there is a deferred_init subclass in subclass_zoo that aims to address the same issue.
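To illustrate the record-and-replay idea only, here is a deliberately simplified toy, DeferredInit (hypothetical, not torchdistX). It records initialization steps as closures over a meta-device module rather than recording individual tensor operations as torchdistX does, and it assumes the module defines reset_parameters. With the same seed, replaying after materialization reproduces eager initialization exactly.

```python
import torch
from torch import nn

class DeferredInit:
    """Toy record/replay: init steps are stored as closures and replayed
    once real storage exists. torchdistX records tensor ops instead."""

    def __init__(self, module_cls, *args, **kwargs):
        with torch.device("meta"):
            self.module = module_cls(*args, **kwargs)
        # Stand-in for the constructor's own initialization.
        self._steps = [lambda m: m.reset_parameters()]

    def record(self, step):
        self._steps.append(step)

    def materialize(self, device, seed):
        self.module.to_empty(device=device)
        torch.manual_seed(seed)
        with torch.no_grad():
            for step in self._steps:
                step(self.module)
        return self.module

deferred = DeferredInit(nn.Linear, 16, 4)
deferred.record(lambda m: nn.init.zeros_(m.bias))
# ... the meta-device module could be inspected or sharded here ...
replayed = deferred.materialize("cpu", seed=0)

# Eager reference: same seed, same sequence of init calls.
torch.manual_seed(0)
eager = nn.Linear(16, 4)
nn.init.zeros_(eager.bias)
```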

What is the “some code” being recorded?

One case where some series of operations needs to be recorded is when initialization is done before sharding and we want the initialization behavior to be replayed. The alternative to tracing might be to do special handling for initialization after sharding, perhaps via implementing a custom reset_parameters for the module.

# let’s say m has some m.weight of shape (2000, 64)
with torch.device('meta'):
    m = SomeModule()
# shard across 10 ranks
_do_sharding(m)
# on one rank m.weight has shape (200, 64)
# want to init this slice of the weight
# allocate storage for m.weight on this rank's device (e.g. 'cuda')
m.to_empty(device='cuda')
# init (200, 64) shard of m.weight (maybe via custom reset_parameters)
m.reset_parameters()
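One way the "custom reset_parameters" alternative could look is sketched below with a hypothetical RowShardedLinear that only holds rows of a (2000, 64) weight. Two assumptions are baked in: the shard uses the same bound as nn.Linear's default init for the full weight (1/sqrt(fan_in), where fan_in = in_features is unchanged by row sharding), and each rank seeds its own generator with a made-up offset scheme so shards differ. It runs on CPU for illustration.

```python
import math
import torch
from torch import nn

class RowShardedLinear(nn.Module):
    """Hypothetical: holds one row shard of an (out_features, in_features) weight."""

    def __init__(self, in_features, out_features, rank, world_size, device=None):
        super().__init__()
        assert out_features % world_size == 0
        self.in_features = in_features
        self.rank = rank
        rows = out_features // world_size
        self.weight = nn.Parameter(torch.empty(rows, in_features, device=device))
        if not self.weight.is_meta:
            self.reset_parameters()

    def reset_parameters(self):
        # Same bound as nn.Linear's default init of the *full* weight.
        bound = 1.0 / math.sqrt(self.in_features)
        # Assumed seeding scheme: a per-rank offset so shards get different values.
        g = torch.Generator(device=self.weight.device).manual_seed(1234 + self.rank)
        with torch.no_grad():
            self.weight.uniform_(-bound, bound, generator=g)

    def forward(self, x):
        return x @ self.weight.T  # this rank's slice of the output features

with torch.device("meta"):
    shard = RowShardedLinear(64, 2000, rank=3, world_size=10)
shard.to_empty(device="cpu")  # allocate only the (200, 64) shard
shard.reset_parameters()      # initialize just this shard
```

Compared with tracing, this avoids any recording machinery, but it requires each sharded module to know how its shard relates to the full tensor.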

Conclusion

In this note, we have mainly aimed to summarize the current state of things and asks/proposals that we have received in the areas of model creation, initialization and serialization. The natural question here is: what is the plan to solve these problems?

We hope to provide composable building blocks that would address the problems above. In particular, we intend to

  1. Harden the meta device context manager
  2. Provide an option for memory-mapped torch.save/torch.load
  3. Improve load_state_dict to mitigate the issue of having 2x parameters in memory

Further, we intend to create benchmarks of loading/saving some LLM OSS models from state dict (exact matrix of models TBD). Separately, after we implement these features, we would like to onboard some users such as HuggingFace and internal customers onto these features.

We have also alluded to other topics within this note such as that of distributed state handling and deferred initialization. We believe that the building blocks discussed above would help to mitigate some of these issues but we do not intend to explicitly work on them for now.


Excellent summary, @mikaylagawarecki - thank you!

On “Problems with our solution for skip initialization”:

You can see how a similar issue has been overcome in DeepSpeed to avoid premature sharding:

The comment in the method explains how and why.

Probably the same logic can be used for premature init.
