Torch.compile with AOT Autograd can be debugged now!

Example code:

import torch
from depyf.explain.backend import aot_eager

@torch.compile(backend=aot_eager)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

import depyf
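# warm-up runs: trigger compilation so depyf can dump the decompiled source into ./debug_dir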
with depyf.prepare_debug(toy_example, "./debug_dir"):
    for _ in range(100):
        input = torch.randn(10, requires_grad=True), torch.randn(10, requires_grad=True)
        toy_example(*input)
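# with the source dumped, set breakpoints in the files under ./debug_dir and step through forward and backward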
with depyf.debug():
    input = torch.randn(10, requires_grad=True), torch.randn(10, requires_grad=True)
    toy_example(*input).sum().backward()
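
If you want to see what was dumped before attaching a debugger, a quick way (my own addition, not from the post above) is to list the dump directory; the exact file names depend on the depyf and PyTorch versions:

import os

# List the decompiled source files written during prepare_debug; these are the
# files to open and set breakpoints in (names vary across versions).
print(sorted(os.listdir("./debug_dir")))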

Example debug screenshots:

Breakpoint at the abs operation in the forward pass of the first subgraph:

Breakpoint at the grad of the abs operation (i.e. the sgn function) in the backward pass of the first subgraph:
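
As a quick aside on why the backward graph contains a sgn node (my own sanity check, not from the original post): the gradient of torch.abs is the sign function, which you can verify in plain eager mode:

import torch

# The gradient of |a| w.r.t. a is sign(a) (0 at a == 0), so AOT Autograd's
# backward graph for the first subgraph uses sgn where the forward used abs.
a = torch.randn(10, requires_grad=True)
torch.abs(a).sum().backward()
assert torch.equal(a.grad, torch.sgn(a))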

For more details, please refer to the depyf repo: https://github.com/thuml/depyf (decompile Python bytecode, debug and understand PyTorch compiler!).

Thanks to @Chillee and @bdhirsh for the helpful discussion in "How does torch.compile work with autograd?".


This is very cool! 🙂
