Torch.optim: Adafactor, those FQNs you wanted, and thanks to our community! (2024 H2)

This half-yearly update is only made possible through viewers (read: contributors) like you. Thank you:

  • @ErezYosef implemented tracking of parameter names/FQNs (via named_parameters()) in the Optimizer state_dict, i.e., you can now do optim = torch.optim.AdamW(model.named_parameters(), …) and the FQNs will show up in optim.state_dict()!
  • A few of you helped widen support for hyperparameters (think LR, betas, weight_decay) being Tensors instead of floats, whether to improve speed through torch.compile() and CUDA graphs, or to differentiate through the optimizer.
  • Adafactor, well, I added this, but @crcrpar also took up the call to make Adafactor faster by landing _foreach_rsqrt and a new _foreach_lerp overload.
  • And many of you made Optimizer and LRScheduler faster, more composable, and better tested overall!

So in detail:

FQNs in the state_dict

The optimizer state_dict has admittedly been confusing, using integer keys to identify the input parameters instead of FQNs (fully qualified names) like the nn.Module state_dict does:

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [{
        'lr': 0.01,
        'weight_decay': 0,
        ...
        'params': [0]
    },
    {
        'lr': 0.001,
        'weight_decay': 0.9,
        ...
        'params': [1, 2, 3]
    }]
}

Practically since the beginning of time (2017), people have complained about this in #1489 (the 21st oldest issue out of 14k in the repo :D). Last year, we discussed some proposals on dev-discuss, for example swapping the integer keys for FQNs from model.named_parameters(), but ultimately did not implement them, as doing so would likely require changes in other parts of the pipeline (e.g., checkpointing). @ErezYosef’s approach skirts these concerns by adding a new param_names association to the state_dict when model.named_parameters() is passed in (see the contrived example state_dict below):

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [{
        'lr': 0.01,
        'weight_decay': 0,
        ...
        'params': [0],
        'param_names': ['layer1.bias']  # NEW!
    },
    {
        'lr': 0.001,
        'weight_decay': 0.9,
        ...
        'params': [1, 2, 3],
        'param_names': ['layer1.weight', 'layer2.weight', 'layer3.weight']  # NEW!
    }]
}

This improves checkpointing and debugging use cases and is a gentle stepping stone towards eventually swapping out the integer keys.
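
For concreteness, here is a minimal sketch of passing named_parameters() to an optimizer and reading param_names back out (the model is made up, and this assumes a build where the feature has landed):

import torch
from torch import nn

# Any model works; its FQNs (e.g., '0.weight') end up in the optimizer state_dict.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Pass (name, parameter) pairs instead of bare parameters.
optim = torch.optim.AdamW(model.named_parameters(), lr=1e-3)

model(torch.randn(16, 4)).sum().backward()
optim.step()

# Each param group now carries 'param_names' alongside the integer 'params' keys.
for group in optim.state_dict()['param_groups']:
    print(group['params'], group['param_names'])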

Introducing our first param-wise optimizer, Adafactor

Adafactor was born from the need for lower-memory optimizers: the incumbent Adam(W) optimizers require memory for two (!) additional states per trainable parameter, momentum and variance. Adafactor replaces both states with a row factor and a column factor, effectively going from 2 × M × N space to M + N space (when a parameter is an M-by-N matrix). Adafactor implementations have long existed outside of torch.optim, but now you can simply write optim = torch.optim.Adafactor(params, …). Try it out and let us know about any problems! In fact, there is already room to improve: I plan to address the factored dims optimizations brought up in rosswightman’s tweets (and also see our performance tracker).
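
As a rough sketch of the memory argument and the drop-in usage (the model and shapes are made up; exact state layouts, e.g., for 1-D bias parameters, may differ):

import torch
from torch import nn

model = nn.Linear(4096, 1024)  # weight is a 1024 x 4096 matrix

# Construction looks like any other torch.optim optimizer.
optim = torch.optim.Adafactor(model.parameters())

model(torch.randn(8, 4096)).sum().backward()
optim.step()

# Back-of-the-envelope extra state for the weight matrix alone:
m, n = model.weight.shape   # 1024, 4096
adamw_state = 2 * m * n     # exp_avg + exp_avg_sq -> 8,388,608 values
adafactor_state = m + n     # row factor + column factor -> 5,120 values
print(adamw_state, adafactor_state)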

Tensor LR, betas, weight_decay

Last year, a sizable hindrance to using torch.compile() or CUDA graphs to make optimizer.step() faster was that a changing LR (learning rate), such as when using an LRScheduler, would cause the computation graph to be recompiled. We have since fixed this, to some extent! A similar problem exists for other hyperparameters, like betas. To better encapsulate dynamism in optimizer hyperparameters, one solution is to use scalar Tensors instead of Python floats. Thanks to work from @mlazos, @qqaatw, and Intel, we now have the following (a short sketch comes after the list):

  • Tensor LR support across all capturable implementations (used in compile and graph capture) for Apple MPS and Intel XPU/HPU (Adam(W), Adadelta, Adagrad, Adamax, NAdam(W), RAdam(W), RMSProp, Rprop, SGD, ASGD)
  • Tensor LR support for LBFGS and SparseAdam
  • Tensor betas support for the capturable implementations for Adam(W)
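
Here is a sketch of the Tensor LR pattern, modeled loosely on the compiled-optimizer recipe (the model, shapes, and schedule are made up; assumes a recent PyTorch 2.x). Wrapping the LR in a 0-dim Tensor lets the scheduler mutate it in place, so a compiled optimizer.step() does not need to recompile when the LR changes:

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(32, 32, device=device)

# A 0-dim Tensor LR that the scheduler updates in place.
optim = torch.optim.AdamW(model.parameters(), lr=torch.tensor(1e-3, device=device))
sched = torch.optim.lr_scheduler.LinearLR(optim, total_iters=5)

@torch.compile(fullgraph=False)
def step():
    optim.step()

for _ in range(5):
    model(torch.randn(8, 32, device=device)).sum().backward()
    step()
    optim.zero_grad()
    sched.step()  # changes the LR Tensor's value, not its identity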

Sometimes, Tensor-ifying the hyperparameters isn’t enough. In particular, to train the hyperparameters themselves and differentiate through the optimizer step, the optimizer also has to support Tensor hyperparameters that require grad! In this regard, we give a shoutout to @EmmettBicker, who has been driving the work on optimizer differentiability.
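
As a hand-rolled sketch of the idea (this illustrates the concept of differentiating through an update step, not the torch.optim code path; the single SGD-style step and the meta loss are made up):

import torch

# A learnable learning rate and a parameter to update.
lr = torch.tensor(0.1, requires_grad=True)
w = torch.randn(3, requires_grad=True)

# Inner loss and a differentiable, out-of-place "SGD step".
inner_loss = (w ** 2).sum()
(g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
w_new = w - lr * g  # the update itself stays on the autograd tape

# A downstream ("meta") loss can now backpropagate into lr itself.
meta_loss = (w_new ** 2).sum()
meta_loss.backward()
print(lr.grad)  # non-None gradient w.r.t. the learning rate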

Bullet List of More Awesome Contributions

Speed

Increased Reliability

Clearer Docs

  • @spzala fixed our torch.optim docs so that the LRScheduler and swa_utils content actually appears
  • We clarified the defaulting behavior in optimizers. To be clear, we default to foreach when on CUDA for most optimizers! (A short sketch follows this list.)
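
A small sketch of pinning the implementation explicitly instead of relying on the default selection (fused availability varies by optimizer, device, and dtype, so treat the commented line as an assumption to verify):

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(64, 64, device=device)

# Leaving foreach/fused unset picks the default (foreach on CUDA for most optimizers).
opt_default = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Or pin an implementation explicitly:
opt_foreach = torch.optim.AdamW(model.parameters(), lr=1e-3, foreach=True)
# opt_fused = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)  # check device/dtype support first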

What is next?

  • Default to the fastest fused implementations now that they’ve baked for a while
  • Improve Adafactor
  • Land @youssef62’s PR to use less memory when beta1 = 0
  • Increase optimizer differentiability with @EmmettBicker
  • What do you want to see next? Post your thoughts below and leave issue links here!