[RFC] Introducing FQNs/clarity 👓 to optim state_dict

Context

5+ years ago, a fateful decision was made to keep optimizers and nn modules disjoint. This means that PyTorch optimizers can be used on any list of Tensors: they do not need to be Parameters and do not have to relate to any NN module at all. Consequently, to keep optimizers unaware of modules, today’s optimizers index their state with integers (essentially list indices). See a sample state_dict:

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3]
        }
    ]
}

Note that the keys in the state entry are mere integers, and that the actual parameters are NOT saved in the state_dict (to save memory, and because the parameters are expected to be supplied from elsewhere). This format is already confusing to parse, and as more users want to mutate the state_dict themselves, a more robust representation becomes more important.
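For concreteness, here is a minimal snippet that produces a state_dict shaped like the one above (a toy model; the exact hyperparameters are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # two parameters: weight and bias
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

model(torch.randn(3, 4)).sum().backward()
optim.step()  # populates the momentum buffers

sd = optim.state_dict()
print(list(sd['state'].keys()))         # [0, 1]  positional indices, not names
print(sd['param_groups'][0]['params'])  # [0, 1]  no parameter names or tensors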

Problem

Overall: our state_dict could be made less confusing. As most use cases for optimizers today cannot be disentangled from NN modules, we should consider creating a more refined user experience for optimizer state_dicts.

Specifically, the first time this arose as a problem was in checkpointing. Users usually save a state_dict (just like the one displayed above) and then later recreate an optimizer instance with their own supplied parameters. Today’s scheme assumes the user passes in parameters in the exact same order as in the saved instance. This is not a problem on its own, as most people retrieve parameters from model.parameters(), which returns them deterministically.
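A sketch of that typical checkpointing flow, where the saved state is matched to the new parameters purely by position (the file name and model are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# ... train, then checkpoint ...
torch.save(optim.state_dict(), 'optim.pt')

# Later: the saved state is matched to the new parameters by position only,
# so they must be supplied in the same order as at save time.
new_model = nn.Linear(4, 2)
new_optim = torch.optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)
new_optim.load_state_dict(torch.load('optim.pt'))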

However, if you have multiple optimizers working on disjoint parameters, for example in any sharding solution such as torchrec or FSDP, each optimizer’s state_dict will share the same indices of 0, 1, 2, etc. (Imagine the above state_dict repeated N times, for N shards.) When checkpointing and resharding, these sharding solutions need to reconcile the duplicated keys by maintaining their own external mapping of parameters to some ID. People have worked around the lack of uniquely identifying keys with concepts like KeyedOptimizer, NamedOptimizer, and various other functions, so this problem already has workarounds, but it could still be worth simplifying the core state_dict for the sake of clarity!
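A toy illustration of the key collision, with two optimizers standing in for two shards:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
# One optimizer per disjoint parameter set, as a sharding solution might create.
opt_a = torch.optim.SGD(model[0].parameters(), lr=0.01)
opt_b = torch.optim.SGD(model[1].parameters(), lr=0.01)

# Both state_dicts index their parameters as 0, 1, ... so the keys collide
# across optimizers and cannot be merged without an external param -> ID mapping.
print(opt_a.state_dict()['param_groups'][0]['params'])  # [0, 1]
print(opt_b.state_dict()['param_groups'][0]['params'])  # [0, 1]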

A big reason why we haven’t immediately jumped to resolve this issue is to uphold backward compatibility with regard to state_dicts. Since the decision to decouple module and optimizer was made 5 years ago, there are many uses of optimizer state_dict today that require careful handling. Thus, any proposed solution would need to be carefully examined to ensure it does not break existing APIs.

Proposal

We want our end state_dict to have customizable unique keys, like:

{
    'state': {
        'linear0.bias': {'momentum_buffer': tensor(...), ...},
        'linear0.weight': {'momentum_buffer': tensor(...), ...},
        'linear1.bias': {'momentum_buffer': tensor(...), ...},
        'linear1.weight': {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0.5,
            ...
            'params': ['linear0.bias']
        },
        {
            'lr': 0.001,
            'weight_decay': 0.0,
            ...
            'params': ['linear0.weight', 'linear1.bias', 'linear1.weight']
        }
    ]
}

We can enable this state_dict by allowing the optimizer constructor to take in named parameters, giving users control over the names. Concretely:

model = ...
# model.named_parameters() yields (str, Parameter) pairs
optim = torch.optim.AdamW(model.named_parameters(), lr=0.001)

If a user passes in named parameters, we will respect the names and maintain them in the state. Otherwise, we will default to integers to maintain backward compatibility. Likewise, if someone calls load_state_dict(new_named_state), we will preserve the names. This proposal would propagate through all layers of state, from initialization to saving to reloading, but we believe it should not break BC.
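For intuition, here is a rough sketch of the renaming this proposal would build in, done by hand on today’s integer-keyed state_dict. The helper name fqn_state_dict is hypothetical, and the sketch assumes the optimizer was constructed from model.parameters() in the same order that model.named_parameters() yields.

# Hypothetical helper: rekey today's integer-indexed state_dict by FQN.
# Assumes the optimizer's params were passed in model.named_parameters() order.
def fqn_state_dict(model, optim):
    names = [name for name, _ in model.named_parameters()]
    sd = optim.state_dict()
    sd['state'] = {names[idx]: buf for idx, buf in sd['state'].items()}
    for group in sd['param_groups']:
        group['params'] = [names[idx] for idx in group['params']]
    return sd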

This is not an original idea. Thanks to the distributed team + others who have made suggestions and continued pushing for a clearer state_dict, and thank you for reading :smiley:

Would this feature benefit you? Do you have feedback on the current proposed solution? Please leave a comment about your use case!

APPENDIX: Anticipated Questions

Wait a moment…we had talked about pytrees. Why not accept pytrees in general?
There was a hot moment where we explored accepting pytrees (i.e., any container format, you can imagine nested dictionaries, etc.) of parameters, which would enable even more flexibility. However, since distributed checkpointing use cases prefer having a flattened dictionary AND people commonly retrieve parameters through model.named_parameters(), there isn’t sufficient reason to get involved with pytrees. It would take nontrivial work to support pytrees throughout the load and save stack, and the overhead of maintaining them is not insignificant. Accepting pytrees would help us rework our param_group API, though our current param_group API isn’t problematic enough to justify a redo. It would also be a natural extension of the current proposal to enable pytree support, so it can be an option to consider further down the road.

So when will this be implemented?
This feature is not yet slotted on my roadmap, but your feedback may influence its priority. I am especially curious whether there are broader use cases beyond distributed checkpointing and torchrec, which have already worked around their issues.


Just to state another valid solution: is this an opportunity to introduce a torch.optim2 that cleans up some of these problems with the existing APIs, while not giving a damn about backward compatibility (because we use a new namespace)?
Or is the change not sufficiently new and innovative to warrant the expense of introducing a v2 API?

It’s less that the change isn’t new or innovative enough, and more that we can solve this concern without the whiplash of introducing a ‘v2’ API. When we discussed the option of a 2.0 to address broader concerns (a superset of this state_dict clarity one), we concluded we could already make significant improvements to the existing API without needing to break BC.

More details are in my internal [meta-only] design doc: https://docs.google.com/document/d/1JJhRCl8F51nH_Ke8Yd_BV3scAmv5V4eSVnle5D8__po/edit#heading=h.gbyekyngph1y

Some notes from new discussions:

TL;DR:

Since we may want to consider supporting pytrees in the future, we will make a slight tweak to the API above, taking in a flattened Dict of string names to parameters instead of an Iterator of tuples.
So, the new API:

# dict(model1.named_parameters()) has type Dict[str, Tensor]
optim = torch.optim.AdamW(dict(model1.named_parameters()), ...)

We do not want to go all the way in supporting pytrees yet due to its significant effect on state structure for checkpointing and the param groups API.

Details below:

One reason to consider pytrees is to define a simple way to pass in parameters from multiple models, like:

optim = torch.optim.AdamW({
    "model1": model1.named_parameters(),
    "model2": model2.named_parameters(),
}, ...)
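For reference, flattening such a nested structure is mechanically simple with torch’s pytree utilities; a minimal sketch below (note that torch.utils._pytree is a private, internal module):

import torch.nn as nn
import torch.utils._pytree as pytree  # private module, subject to change

nested = {
    'model1': dict(nn.Linear(4, 4).named_parameters()),
    'model2': dict(nn.Linear(4, 2).named_parameters()),
}
leaves, spec = pytree.tree_flatten(nested)
print(len(leaves))  # 4 parameter tensors; `spec` records how to rebuild the nesting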

However, while this is not particularly hard to handle computationally (we can just call tree_flatten), it does affect the state structure: the keys would go from integers to a nested structure, rather than from integers to strings as in the current proposal. Checkpointing solutions can already handle strings in the state_dict but would have to figure out how to deal with nested structures if we introduce them. Moreover, switching to pytrees will require reconsidering the param_groups API. In the current proposal, we would just replace every instance of parameters with named parameters, like so:

optim = torch.optim.AdamW([
    {'params': model1.named_parameters(), 'lr': 0.01},
    {'params': model2.named_parameters(), 'weight_decay': 0.9},
    ...
])

When we consider pytrees, we have two options. We could blindly swap out parameters for pytrees again:

optim = torch.optim.AdamW([
    {'params': params1, 'lr': 0.01},           # params1 is a pytree
    {'params': params2, 'weight_decay': 0.9},  # params2 is a pytree
    ...
])

Note that the params are now separated, meaning we no longer take advantage of the easy pytree definition; users would have to figure out how they want to group them. An alternative is to take in a pytree for every hyperparameter, which then forces users to build out pytrees themselves:

optim = torch.optim.AdamW(params, lrs, weight_decays, ...)  # each argument is a pytree

Since both alternatives will force users to confront pytrees or do extra processing anyway, sticking with a flattened dictionary API for now seems the simplest solution. We then propose:

optim = torch.optim.AdamW(flattened_named_parameters, ...)  # Dict[str, Tensor]

Users can either build their own flattened dictionaries, or they can take advantage of nn module structure, which already provides flattened names via named_parameters(). We also propose switching from Iterator to Dict to give ourselves a doorway to pytrees later down the road.
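A small illustration of both options for building the flattened Dict[str, Tensor] (hypothetical usage; today’s constructors do not accept a dict, and the ‘encoder’ prefix is just an example):

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# 1. Lean on the module's own naming:
flat = dict(model.named_parameters())  # {'0.weight': ..., '0.bias': ..., ...}

# 2. Build it by hand, e.g. to disambiguate multiple models with a prefix:
flat = {f'encoder.{name}': p for name, p in model.named_parameters()}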