This half-year update is only made possible through contributors like you, so thank you:
- @ErezYosef implemented tracking parameter names/FQNs (named_parameters) in the Optimizer state_dict, i.e., you can now do optim = torch.optim.AdamW(model.named_parameters(), …) and the FQNs will be trackable in optim.state_dict()!
- A few of you helped widen support for hyperparameters (think lr, betas, weight_decay) being Tensors instead of floats, whether to improve speed through torch.compile() and CUDA graphs or to differentiate through the optimizer.
- Adafactor, well, I added this, but @crcrpar also took up the call to make Adafactor faster by landing _foreach_rsqrt and a new _foreach_lerp overload.
- And many of you made Optimizer and LRScheduler faster, more composable, and better tested overall!
So in detail:
FQNs in the state_dict
The optimizer state_dict has admittedly been confusing, using integer keys to associate input parameters instead of FQNs (fully qualified names) as in the nn.Module state_dict:
{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
        },
        {
            'lr': 0.001,
            'weight_decay': 0.9,
            ...
            'params': [1, 2, 3]
        }
    ]
}
People have complained about this since nearly the beginning (2017) in #1489 (the 21st oldest issue out of 14k in the repo :D). Last year, we discussed some proposals on dev-discuss, for example swapping the integer keys for FQNs from model.named_parameters(), but ultimately did not implement them, as they would likely require changes in other parts of the pipeline (e.g., checkpointing). @ErezYosef’s approach skirts these concerns by adding a new param_names association in the state_dict when model.named_parameters() is provided (see the contrived example state_dict below):
{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0],
            'param_names': ['layer1.bias']  # NEW!
        },
        {
            'lr': 0.001,
            'weight_decay': 0.9,
            ...
            'params': [1, 2, 3],
            'param_names': ['layer1.weight', 'layer2.weight', 'layer3.weight']  # NEW!
        }
    ]
}
This will improve checkpointing and debugging use cases and is a gentle stepping stone toward eventually swapping out the integer keys.
Introducing our first param-wise optimizer, Adafactor
Adafactor’s nascence came from the need for lower-memory optimizers, as the incumbent Adam(W) optimizers require memory for two (!) additional states per trainable parameter: momentum and variance. Adafactor replaces both states with a row factor and a column factor, effectively going from requiring 2 × M × N space to M + N space for an M-by-N parameter matrix. Adafactor implementations have long existed outside of torch.optim, but now you can simply specify optim = torch.optim.Adafactor(params, …). Try it out and let us know about any problems! In fact, there’s already room to improve: I plan to address the factored-dims optimizations brought up in rosswightman’s tweets (and also see our performance tracker).
Tensor LR, betas, weight_decay
Last year, a sizable hindrance to using torch.compile() or CUDA graphs to speed up optimizer.step() was that a changing LR (learning rate), e.g., from an LRScheduler, would cause the computation graph to be recompiled. We have since fixed this, to some extent! A similar problem exists for other hyperparameters, like betas. To better encapsulate dynamism in optimizer hyperparameters, one solution is to use scalar Tensors instead of Python floats. Thanks to work from @mlazos, @qqaatw, and Intel, we now have
- Tensor LR support across all capturable implementations (used in compile and graph capture) for Apple MPS and Intel XPU/HPU (Adam(W), Adadelta, Adagrad, Adamax, NAdam(W), RAdam(W), RMSProp, Rprop, SGD, ASGD)
- Tensor LR support for LBFGS and SparseAdam
- Tensor betas support for the capturable implementations for Adam(W)
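The basic pattern looks like this (a sketch on CPU with Adam; for compile/graph-capture use cases you would additionally pass capturable=True or fused=True where supported):

```python
import torch

model = torch.nn.Linear(2, 2)

# A scalar Tensor lr can be mutated in place between steps, so a
# compiled or captured optimizer.step() sees the same graph each time.
lr = torch.tensor(0.01)
optim = torch.optim.Adam(model.parameters(), lr=lr)

model(torch.randn(1, 2)).sum().backward()
optim.step()

# "Schedule" the LR by mutating the tensor rather than rebuilding anything.
lr.mul_(0.5)
optim.step()
```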
Sometimes, Tensorfying the hyperparameters isn’t enough. In particular, when attempting to train on hyperparameters and differentiate through the optimizer step itself, we have to support Tensor hyperparameters that require grad in the optimizer, too! Here, we give a shoutout to @EmmettBicker, who:
- enabled support for differentiable lr and weight_decay in SGD
- added differentiable lr, weight_decay and betas in Adam(W)
- is working on expanding this support to more optimizers, see our plan
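Conceptually, differentiating through the step behaves like this hand-rolled, out-of-place sketch (illustration only, not the torch.optim implementation, which you get via the differentiable=True constructor flag):

```python
import torch

# A learnable parameter and a learnable learning rate.
w = torch.randn(3, requires_grad=True)
lr = torch.tensor(0.1, requires_grad=True)

# Inner step: keep the SGD update on the graph via create_graph=True.
inner_loss = (w ** 2).sum()
(g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
w_new = w - lr * g  # out-of-place update, differentiable w.r.t. lr

# Outer step: gradients flow through the update into the hyperparameter.
outer_loss = (w_new ** 2).sum()
outer_loss.backward()
print(lr.grad)  # we can now "train" the learning rate
```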
Bullet List of More Awesome Contributions
Speed
- @alpha0422 removed unnecessary H2D Sync in LRScheduler when using Tensor LR
- I made _foreach_norm faster by removing dispatch overhead.
- @Shan19900305 made foreach ops skip the stride check for size-1 tensors, enabling more parameters to enroll in the faster, horizontally fused runtime
Increased Reliability
- Empty tensors struck again, causing inaccuracy in foreach reductions, and @ngimel debugged this for hopefully the very last time.
- I fixed a bug in ReduceLROnPlateau’s interaction with add_param_group, where modifying your Optimizer param groups would error
- AdamW has been refactored to reuse the majority of Adam code, eliminating ~700 lines of duplication, thanks to @tfsingh and @EmmettBicker
- ChainedScheduler can be nested within a SequentialLR now thanks to @mattpitkin
- @crcrpar moved the params’ device check from the constructor to later to allow more use cases with meta device params
- The community helped increase optim test coverage using the new OptimInfos while solving #123451
Clearer Docs
- @spzala fixed our torch.optim docs to make LRScheduler and swa_util content appear
- We clarified the defaulting behavior in optimizers. To be clear, we default to foreach when on CUDA for most optimizers!
What is next?
- Default to the fastest fused implementations now that they’ve baked for a while
- Improve Adafactor
- Land @youssef62’s PR to use less memory when beta1 = 0
- Increase optimizer differentiability with @EmmettBicker
- What do you want to see next? Post your thoughts below and leave issue links here!