Example code:
from typing import List
import torch
from torch import _dynamo as torchdynamo
from depyf.explain.backend import eager, aot_eager

@torch.compile(backend=aot_eager)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:  # data-dependent branch: Dynamo breaks the graph here
        b = b * -1
    return x * b

import depyf

# First, run the function under `prepare_debug` so that depyf dumps the
# source code involved in compilation into ./debug_dir.
with depyf.prepare_debug(toy_example, "./debug_dir"):
    for _ in range(100):
        input = torch.randn(10, requires_grad=True), torch.randn(10, requires_grad=True)
        toy_example(*input)

# Then, run it again under `debug()` to attach a debugger and set breakpoints
# in the dumped source files.
with depyf.debug():
    input = torch.randn(10, requires_grad=True), torch.randn(10, requires_grad=True)
    toy_example(*input).sum().backward()
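For orientation, the data-dependent branch if b.sum() < 0 causes Dynamo to split toy_example into subgraphs, and the screenshots below refer to the first of these subgraphs. A hand-written sketch of what that first subgraph roughly computes (illustrative only, not depyf's generated output):

def first_subgraph_sketch(a, b):
    # names are made up for illustration; depyf's dumped code looks different
    x = a / (torch.abs(a) + 1)  # the abs op where the forward-pass breakpoint sits
    cond = b.sum() < 0          # the condition that triggers the graph break
    return x, cond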
Example debug screenshots:

[Screenshot: breakpoint at the abs operation in the forward pass of the first subgraph.]

[Screenshot: breakpoint at the grad of the abs operation (i.e. the sgn function) in the backward pass of the first subgraph.]
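As a quick sanity check of why the backward breakpoint lands at sgn (this standalone snippet is my own illustration, independent of depyf): the gradient of torch.abs is the sign function.

a = torch.randn(10, requires_grad=True)
torch.abs(a).sum().backward()
assert torch.allclose(a.grad, torch.sgn(a))  # d|a|/da == sgn(a) elementwise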
For more details, please refer to the GitHub repo thuml/depyf (decompile Python bytecode, debug and understand the PyTorch compiler): https://github.com/thuml/depyf
Thanks to @Chillee and @bdhirsh for the helpful discussion in the thread "How does torch.compile work with autograd?".