Context
5+ years ago, a fateful decision was made to keep optimizers and nn modules disjoint. This means that PyTorch optimizers can be used on any list of Tensors: they do not need to be Parameters and do not have to relate to any NN module at all. Consequently, to stay agnostic of modules, today's optimizers index their parameters with integers (essentially list indices). See a sample state_dict:
{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3]
        }
    ]
}
Note that the keys in the state entry are mere integers, and that the actual parameters are NOT saved in the state_dict. This saves memory and reflects the fact that parameters are expected to be supplied separately when the optimizer is reconstructed. This format is already confusing to parse, and as more users want to mutate the state_dict themselves, having a clear and robust representation becomes more important.
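To make the integer indexing concrete, here is a minimal sketch (the toy two-Linear model and hyperparameters are made up for illustration, not taken from the proposal) showing that the keys are simply the positions of the parameters in the order they were passed to the optimizer, flattened across param_groups:

import torch

# Toy model: Linear(4, 4) has a weight and bias, as does Linear(4, 2).
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
params = list(model.parameters())

# Two param_groups, mirroring the sample state_dict above.
optim = torch.optim.SGD(
    [
        {"params": params[:1], "lr": 0.01, "weight_decay": 0},
        {"params": params[1:], "lr": 0.001, "weight_decay": 0.5},
    ],
    momentum=0.9,
)

# One forward/backward/step to populate the momentum buffers.
model(torch.randn(1, 4)).sum().backward()
optim.step()

print(sorted(optim.state_dict()["state"].keys()))  # [0, 1, 2, 3]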
Problem
Overall: our state_dict could be made less confusing. Since most optimizer use cases today cannot be disentangled from NN modules, we should consider creating a more refined user experience for optimizer state_dicts.
Specifically, this first surfaced as a problem in checkpointing. Users typically save a state_dict (just like the one displayed above) and later recreate an optimizer instance with their own supplied parameters. Today's scheme assumes the user passes parameters in exactly the same order as the parameters of the saved instance. This is not so bad on its own, since most people retrieve parameters from model.parameters(), which returns them in a deterministic order.
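As a toy illustration of that order dependence (the model and hyperparameters here are made up), reloading the same state_dict into an optimizer whose parameters were supplied in a different order reassigns the saved state purely by position:

import torch

model = torch.nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
model(torch.randn(1, 4)).sum().backward()
optim.step()
saved = optim.state_dict()  # state keyed 0 (weight), 1 (bias)

# Rebuild the optimizer over the same tensors, but in reversed order.
optim2 = torch.optim.SGD(list(model.parameters())[::-1], lr=0.01, momentum=0.9)
optim2.load_state_dict(saved)
# Index 0 in the saved state now refers to the bias rather than the weight,
# so the momentum buffers end up attached to the wrong parameters.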
However, if you have multiple optimizers working on disjoint parameters, for example in any sharding solution such as torchrec or FSDP, each optimizer's state_dict will share the same indices of 0, 1, 2, etc. (Imagine the above state_dict repeated N times, once per shard.) When checkpointing and resharding, these sharding solutions need to reconcile the duplicated keys by maintaining their own external mapping of parameters to some ID. People have worked around the lack of uniquely identifying keys with concepts like KeyedOptimizer, NamedOptimizer, and various other functions, so this problem already has workarounds, but it could still be worth simplifying the core state_dict for the sake of clarity!
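As a toy stand-in for a real sharding setup, two optimizers over disjoint parameter sets each produce keys starting at 0, so the keys alone cannot identify which parameter a piece of state belongs to:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
opt_shard0 = torch.optim.SGD(model[0].parameters(), lr=0.01, momentum=0.9)
opt_shard1 = torch.optim.SGD(model[1].parameters(), lr=0.01, momentum=0.9)

model(torch.randn(1, 4)).sum().backward()
opt_shard0.step()
opt_shard1.step()

print(sorted(opt_shard0.state_dict()["state"].keys()))  # [0, 1]
print(sorted(opt_shard1.state_dict()["state"].keys()))  # [0, 1] -- same keys, different parameters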
A big reason why we haven’t immediately jumped to resolve this issue is to uphold backwards compatibility with regard to state_dicts. Since the decision to decouple module and optimizer was made 5 years ago, there are many uses of optimizer state_dict today that require careful handling. Thus, any proposed solution would need to be carefully vetted to ensure it does not break existing APIs.
Proposal
We want the resulting state_dict to have customizable, unique keys, like:
{
    'state': {
        'linear0.bias': {'momentum_buffer': tensor(...), ...},
        'linear0.weight': {'momentum_buffer': tensor(...), ...},
        'linear1.bias': {'momentum_buffer': tensor(...), ...},
        'linear1.weight': {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0.5,
            ...
            'params': ['linear0.bias']
        },
        {
            'lr': 0.001,
            'weight_decay': 0.0,
            ...
            'params': ['linear0.weight', 'linear1.bias', 'linear1.weight']
        }
    ]
}
We can enable this state_dict by allowing the optimizer constructor to take in named parameters, giving users control over the names. Concretely:
model = ...
# named_parameters() yields an Iterator[Tuple[str, Parameter]]
optim = torch.optim.AdamW(model.named_parameters(), lr=0.001)
If a user passes in named parameters, we will respect the names and maintain them in the state. Otherwise, we will default to integers to maintain backwards compatibility (BC). Likewise, if someone calls load_state_dict(new_named_state), we will preserve the names. This proposal would flow through every layer of state, from initialization to saving to reloading, but we believe it should not break BC.
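For concreteness, here is a hypothetical sketch of how the proposed behavior could look end to end once implemented; the flow and the resulting keys are illustrative of the proposal, not an existing API:

import torch

# Under the proposal, names from named_parameters() would key the state.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
optim = torch.optim.AdamW(model.named_parameters(), lr=0.001)

# ... train, then checkpoint ...
sd = optim.state_dict()
# sd["state"] would be keyed by "0.weight", "0.bias", "1.weight", "1.bias",
# and sd["param_groups"][0]["params"] would list those same names.

# On reload, entries would be matched by name rather than by position:
new_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
new_optim = torch.optim.AdamW(new_model.named_parameters(), lr=0.001)
new_optim.load_state_dict(sd)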
This is not an original idea. Thanks to the distributed team and others who have made suggestions and kept pushing for a clearer state_dict, and thank you for reading!
Would this feature benefit you? Do you have feedback on the current proposed solution? Please leave a comment about your use case!
APPENDIX: Anticipated Questions
Wait a moment…we had talked about pytrees. Why not accept pytrees in general?
There was a hot moment where we explored accepting pytrees (i.e., any nested container format, such as nested dictionaries) of parameters, which would enable even more flexibility. However, since distributed checkpointing use cases prefer a flattened dictionary AND people commonly retrieve parameters through model.named_parameters(), there isn't sufficient reason to get involved with pytrees. It would take nontrivial work to support pytrees throughout the load and save stack, and the maintenance overhead is not insignificant. Accepting pytrees would help us rework our param_group API, but the current param_group API isn't problematic enough to justify a redo. Pytree support would also be a natural extension of the current proposal, so it remains an option to consider further down the road.
So when will this be implemented?
This feature is not yet slotted on my roadmap, but your feedback may influence its priority. I am especially curious whether there are broader use cases beyond distributed checkpointing and torchrec, which have already worked around their issues.