[RFC] Introducing FQNs/clarity 👓 to optim state_dict

Some notes from new discussions:

TL;DR:

Since we may want to consider supporting pytrees in the future, we will make a slight tweak to the API above, taking in a flattened Dict of string names to parameters instead of an Iterator of tuples.
So, the new API:

optim = torch.optim.AdamW(dict(model1.named_parameters()), ...)  # params argument is a Dict[str, Tensor]
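For illustration, a minimal sketch of the flattened, FQN-keyed dict this would take in, built with today's nn.Module APIs (the optimizer call itself is the proposal above, not current behavior):

import torch.nn as nn

model1 = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# dict() over named_parameters() flattens the (FQN, Tensor) pairs into the
# proposed Dict[str, Tensor] input.
flattened = dict(model1.named_parameters())
print(list(flattened.keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']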

We do not want to go all the way in supporting pytrees yet due to its significant effect on state structure for checkpointing and the param groups API.

Details below:

One reason to consider pytrees is to define a simple way to pass in parameters from multiple models, like:

optim = torch.optim.AdamW({
    "model1": model1.named_parameters(),
    "model2": model2.named_parameters(),
}, ...)
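For reference, flattening such a nested structure for computation is straightforward with PyTorch's pytree utilities; a sketch follows (torch.utils._pytree is a private module, so treat this as illustrative):

import torch.nn as nn
from torch.utils._pytree import tree_flatten

model1 = nn.Linear(4, 4)
model2 = nn.Linear(4, 2)

nested = {
    "model1": dict(model1.named_parameters()),
    "model2": dict(model2.named_parameters()),
}

# tree_flatten returns the leaf Tensors plus a TreeSpec describing the nesting,
# so the optimizer could still iterate over a flat list of params internally.
leaves, spec = tree_flatten(nested)
print(len(leaves))  # 4 parameter tensors (2 weights + 2 biases)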

However, while this is not particularly hard to handle for computation (we can just call tree_flatten), it does affect the state structure: the state keys would go from integers to a nested structure, rather than from integers to strings as in the current proposal. Checkpointing solutions can already handle string keys in the state_dict, but they would have to figure out how to deal with nested structures if we introduce them. Moreover, switching to pytrees would require reconsidering the param_groups API. In the current proposal, we would just replace every instance of parameters with named parameters, like so:

optim = torch.optim.AdamW([
    {"params": model1.named_parameters(), "lr": 0.01},
    {"params": model2.named_parameters(), "weight_decay": 0.9},
    ...
])
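To make the checkpointing concern above concrete, here is a purely hypothetical illustration of how the optimizer state keys would differ across the three shapes being discussed (placeholder values, not real tensor state):

# Today: state keyed by integer parameter index.
state_today = {
    0: {"step": 10, "exp_avg": "<tensor>"},
    1: {"step": 10, "exp_avg": "<tensor>"},
}

# Current proposal: state keyed by flat FQN strings, which checkpointing
# solutions can already handle.
state_fqn = {
    "0.weight": {"step": 10, "exp_avg": "<tensor>"},
    "0.bias": {"step": 10, "exp_avg": "<tensor>"},
}

# Pytree input: keys would become nested structures, which checkpointing
# would additionally have to learn to traverse.
state_pytree = {
    "model1": {"0.weight": {"step": 10, "exp_avg": "<tensor>"}},
    "model2": {"0.weight": {"step": 10, "exp_avg": "<tensor>"}},
}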

When we consider pytrees, we have two options. We could blindly swap out parameters for pytrees again:

optim = torch.optim.AdamW([
    {"params": params1, "lr": 0.01},           # params1: pytree
    {"params": params2, "weight_decay": 0.9},  # params2: pytree
    ...
])

Note that the params are now separated, meaning we no longer take advantage of the easy pytree definition; users will have to figure out how they want to group parameters themselves. An alternative is to take in a pytree for every hyperparameter, which then forces users to build out pytrees:

optim = torch.optim.AdamW(params: pytree, lrs: pytree, weight_decays: pytree, ...)
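A sketch of what that would mean in practice, using tree_map only for illustration (the constructor call below is hypothetical, not a real API):

import torch.nn as nn
from torch.utils._pytree import tree_map

model1 = nn.Linear(4, 4)
model2 = nn.Linear(4, 2)

params = {
    "model1": dict(model1.named_parameters()),
    "model2": dict(model2.named_parameters()),
}

# The user must mirror the structure of `params` for every hyperparameter.
lrs = tree_map(lambda p: 0.01, params)
weight_decays = tree_map(lambda p: 0.9 if p.ndim > 1 else 0.0, params)

# Hypothetical call under this alternative:
# optim = torch.optim.AdamW(params, lrs=lrs, weight_decays=weight_decays)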

Since both alternatives will force users to confront pytrees or do extra processing anyway, sticking with a flattened dictionary API for now seems the simplest solution. We then propose:

optim = torch.optim.AdamW(flattened_named_parameters: Dict[str, Tensor], ...)

The user can either build their own flattened dictionaries or take advantage of nn.Module structure, which already provides flattened FQNs via named_parameters(). We also propose switching from Iterator to Dict to give ourselves a doorway to pytrees later down the road.
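To close, a sketch of both routes to a flattened dictionary (the constructor call remains the proposed API, not the current torch.optim signature):

import torch.nn as nn

model1 = nn.Linear(4, 4)
model2 = nn.Linear(4, 2)

# Route 1: lean on nn.Module structure for FQN-keyed parameters.
flattened = dict(model1.named_parameters())

# Route 2: build your own flat dict, e.g. to combine several models
# under distinct prefixes.
flattened_multi = {
    **{f"model1.{name}": p for name, p in model1.named_parameters()},
    **{f"model2.{name}": p for name, p in model2.named_parameters()},
}

# Proposed call:
# optim = torch.optim.AdamW(flattened_multi, lr=1e-3)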