How to get the backward graph while using torch.export?

I have read references such as topic-1621 and aot_autograd, but I am unable to find a way to export the backward graph while using torch.export. Is there any trick to do this? A simple code script is below:

import torch
from torch._functorch.aot_autograd import aot_module

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.linear(x)
        return x

inp = torch.randn(2, 3, device="cuda", requires_grad=True)
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))

# check graph module
print(ep.module().code)

# Run it with torch.compile
compile_module = torch.compile(ep.module(), backend="inductor")
res = compile_module(inp)

res.sum().backward()
# how to print the backward graph?

Thank you for your support!

I can see the backward graph in the debug logs after setting TORCH_COMPILE_DEBUG=1. However, can I extract this graph directly through an API?

This post may be helpful: How to set wrap function using TorchDynamo graph capture?
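Along those lines, here is a minimal sketch of capturing both graphs through a custom torch.compile backend. It assumes the aot_autograd wrapper from torch._dynamo.backends.common, which is a private API and may change between releases; depending on the version, the backward graph may only be captured once backward() actually runs.

import torch
from functorch.compile import make_boxed_func
from torch._dynamo.backends.common import aot_autograd

captured = []

def capture(gm: torch.fx.GraphModule, example_inputs):
    # AOTAutograd calls this once for the forward graph and once for the backward graph.
    captured.append(gm)
    return make_boxed_func(gm.forward)

# Wrap the capturing compiler into a backend that torch.compile understands.
backend = aot_autograd(fw_compiler=capture, bw_compiler=capture)

compiled_module = torch.compile(ep.module(), backend=backend)
res = compiled_module(inp)
res.sum().backward()

for g in captured:
    print(g.code)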

Another way to do this, as far as I know, is something like:

import torch
from functorch.compile import aot_module, make_boxed_func

captured_graphs = []

def custom_compiler(gm: torch.fx.GraphModule, _):
    # AOTAutograd hands the traced graph module to this callback; stash it for inspection.
    captured_graphs.append(gm)
    return make_boxed_func(gm.forward)

# model / inputs: the module and example inputs you want to capture.
aot_model = aot_module(model, fw_compiler=custom_compiler)
y = aot_model(inputs)
y.sum().backward()
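If I remember correctly, aot_module falls back to fw_compiler when no bw_compiler is given, so after the backward() call captured_graphs should hold both graphs; you can also pass bw_compiler=custom_compiler explicitly. A small usage sketch for inspecting them (the capture order, forward first and backward second, is an assumption):

# After y.sum().backward() has run, dump everything that was captured.
for i, g in enumerate(captured_graphs):
    print(f"--- captured graph {i} ---")
    print(g.code)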

Thanks for your reply. I have tried this before, but it failed with some models. I think aot_module is less expressive than using torch.export directly.

That makes sense. I also wonder how to use export to capture the backward graph. :smiley:
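For an export-style route, one sketch I have seen uses aot_export_module to trace a single joint forward-and-backward graph. Note that aot_export_module lives in the private torch._functorch.aot_autograd namespace, its signature may differ across versions, and the module has to expose a scalar loss among its outputs:

import torch
from torch._functorch.aot_autograd import aot_export_module

class MWithLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        # Joint tracing wants a scalar loss among the outputs.
        return (self.linear(x).sum(),)

m = MWithLoss()
inp = torch.randn(2, 3, requires_grad=True)

# trace_joint=True asks AOTAutograd for one graph containing both forward and backward;
# output_loss_index says which output the gradients are taken with respect to.
joint_gm, graph_signature = aot_export_module(m, (inp,), trace_joint=True, output_loss_index=0)
print(joint_gm.code)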