Compiled Optimizer w/ LR Scheduler Now Supported

Supporting LRScheduler with Compiled Optimizer

Background

Many users have attempted to use an LRScheduler with the compiled optimizer, only to run into recompilation issues. This happened because torch.compile guards on the scalar value of the LR, so every time the LRScheduler changed the LR it triggered a recompile. This is best illustrated by the following example.

import torch

# Create a simple model
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
)
input = torch.rand(1024, device="cuda")

# run forward pass
output = model(input)

# run backward to populate the grads for our optimizer below
output.sum().backward()

# The LR is a plain Python float here, so torch.compile guards on its value
opt = torch.optim.Adam(model.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()

# Set up logging to view recompiles
torch._logging.set_logs(recompiles=True)

# Warmup runs to compile the function
# We will now recompile on each iteration
# as the value of the lr is mutated.
for _ in range(2):
    fn()

# Sample Output:
#
# >>[DEBUG]:Recompiling function step in /data/users/mlazos/pytorch/torch/optim/adam.py:191
# >>    triggered by the following guard failure(s):
# >>    - L['self'].param_groups[0]['lr'] == 0.003333333333333333

In the above example, you can see from the sample output that the optimizer is recompiled on the iteration where the LR has changed.

Implementation

To rectify this, support was added for the LR to be a tensor. torch.compile treats tensors as graph inputs and does not guard on the data they contain, which resolves the recompile problem above. The LR schedulers were modified to update the LR whether it is a tensor (mutated in place) or a Python scalar, so both cases are supported. As a result, users combining an LR scheduler with the compiled optimizer should wrap the LR in a tensor to avoid these recompiles. The final result is shown in the following example.

# Reuse the model and gradients from above, but wrap the LR in a tensor
opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

@torch.compile(fullgraph=False)
def fn():
    opt.step()
    sched.step()

# Set up logging to view recompiles
torch._logging.set_logs(recompiles=True)

# The first iteration compiles the function; subsequent LR changes
# do not trigger recompiles.
for _ in range(5):
    fn()
    print(opt.param_groups[0]["lr"])

# Sample Output:
#
# >> tensor(0.0047)
# >> tensor(0.0060)
# >> tensor(0.0073)
# >> tensor(0.0087)
# >> tensor(0.0100)

From the sample output we can see that the recompiles no longer appear. For a full runnable script see the tutorial.
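As a quick sanity check on the in-place mutation described above (a small illustrative snippet, assuming the in-place update behavior just described), you can verify that the scheduler keeps updating the same tensor object:

lr = torch.tensor(0.01)
opt = torch.optim.Adam(model.parameters(), lr=lr)
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

opt.step()
sched.step()

# The scheduler mutates the LR tensor in place rather than replacing it,
# so the param group still holds the original tensor object.
print(opt.param_groups[0]["lr"] is lr)  # expected: True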

Next Steps

In the future, it should be possible to do this LR tensor wrapping automatically in the compiler, but this would not solve the problem if the user attaches the LR scheduler to the optimizer after compiling.
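For users who already construct their optimizer with a float LR, a helper along the lines of the sketch below (purely hypothetical, not a PyTorch API) could do the wrapping manually, as long as it runs before the scheduler is attached and the function is compiled:

def wrap_lr_in_tensor(optimizer):
    # Hypothetical helper: convert each param group's Python-float LR
    # into a tensor so torch.compile treats it as a graph input.
    for group in optimizer.param_groups:
        if not isinstance(group["lr"], torch.Tensor):
            group["lr"] = torch.tensor(group["lr"])
    return optimizer

opt = wrap_lr_in_tensor(torch.optim.Adam(model.parameters(), lr=0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)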


Hi @mlazos, in my case I tried wrapping the LR in a tensor, and recompilations were reduced, but I observed worse performance: the latency increased from 10s to 11s. Could the wrapping be introducing any overhead, maybe more guards?

Could you share a repro? This is pretty unexpected but it could be possible for some optimizers if we end up generating more kernels.
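For reference, one way to inspect the code Inductor generates (and get a rough sense of how many kernels were produced) is to enable output-code logging:

# Print the Triton/C++ code Inductor generates for each compiled graph
torch._logging.set_logs(output_code=True)

# or, equivalently, from the command line:
# TORCH_LOGS="output_code" python your_script.py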

Thanks for your reply. Since the codebase is quite large, creating a minimal repro isn't easy.
The workload fine-tunes the Llama decoder layers one by one. I compile the fine-tuning function as below:

@torch.compile()
def opt_layer(layer, layer_inputs, layer_labels):
    # Fine-tune a single decoder layer on its cached inputs and labels.
    # mse, optimizer, llama_model, layer_inputs, and layer_labels are
    # defined elsewhere in the codebase.
    for inp, label in zip(layer_inputs, layer_labels):
        layer_output = layer(inp)
        loss = mse(layer_output, label)
        loss.backward()
        optimizer.step()


for layer in llama_model.Layers:
    opt_layer(layer, layer_inputs, layer_labels)

Could you let me know how to check the number of generated kernels?