Supporting LRScheduler with Compiled Optimizer
Background
Many users have attempted to use the LRScheduler with the compiled optimizer, only to run into recompilation issues. The cause is that torch.compile guards on the scalar value of the LR, so every time the LRScheduler changes that value, the optimizer step is recompiled. This is best illustrated by the following example.
import torch

# Create a simple model
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")
# Run the forward pass
output = model(input)
# Run backward to populate the grads for our optimizer below
output.sum().backward()
# Note: the LR is a plain Python float here, not wrapped in a tensor
opt = torch.optim.Adam(model.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()
# Setup logging to view recompiles
torch._logging.set_logs(recompiles=True)
# Warmup runs to compile the function
# We will now recompile on each iteration
# as the value of the lr is mutated.
for _ in range(2):
    fn()
# Sample Output:
#
# >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
# >> triggered by the following guard failure(s):
# >> - L['self'].param_groups[0]['lr'] == 0.003333333333333333
In the sample output above, you can see that the optimizer step is recompiled on the iteration after the LR has been changed, because the guard on the LR's scalar value fails.
Implementation
To fix this, support was added to allow the LR to be a tensor. torch.compile treats tensors as graph inputs and does not guard on the data they contain, which removes the recompiles shown above. The LR schedulers were modified to mutate the LR in place, whether it is a tensor or a Python scalar, so both cases are supported. As a result, users who combine an LR scheduler with the compiled optimizer should wrap the LR in a tensor to avoid these recompiles. The final result is shown in the following example.
opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)
@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()
# Setup logging to view recompiles
torch._logging.set_logs(recompiles=True)
# Warmup runs to compile the function
for _ in range(5):
    fn()
    print(opt.param_groups[0]["lr"])
# Sample Output:
#
# >> tensor(0.0047)
# >> tensor(0.0060)
# >> tensor(0.0073)
# >> tensor(0.0087)
# >> tensor(0.0100)
From the sample output, we can see that the recompiles no longer appear. For a full runnable script, see the tutorial.
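For context, the in-place update described in the Implementation section can be sketched as follows. This is a simplified illustration, not the actual PyTorch scheduler source, and the function name apply_new_lr is made up for this example.
import torch

def apply_new_lr(param_group, new_lr):
    # If the LR is a tensor, mutate it in place; the compiled graph reads
    # the updated value on the next step without recompiling.
    if isinstance(param_group["lr"], torch.Tensor):
        param_group["lr"].fill_(new_lr)
    else:
        # Plain Python scalar: rebinding changes the value torch.compile
        # guarded on, which is what triggers the recompiles shown earlier.
        param_group["lr"] = new_lr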
Next Steps
In the future, it should be possible to perform this LR tensor wrapping automatically in the compiler, but that would not solve the problem if the user attaches the LR scheduler to the optimizer after compiling.
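Until such automation exists, a small user-side helper can wrap any scalar LRs before the scheduler is constructed and the optimizer step is compiled. This is a hypothetical convenience sketch, not part of the PyTorch API.
import torch

def wrap_lrs_in_tensors(optimizer):
    # Hypothetical helper: replace any scalar LR in each param group with
    # a tensor so torch.compile treats it as a graph input rather than a
    # guarded constant.
    for group in optimizer.param_groups:
        if not isinstance(group["lr"], torch.Tensor):
            group["lr"] = torch.tensor(group["lr"])
    return optimizer

# Usage sketch: wrap before constructing the scheduler and compiling.
# opt = wrap_lrs_in_tensors(torch.optim.Adam(model.parameters(), lr=0.01))
# sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)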