I have read references such as topic-1621 and the aot_autograd notes, but I cannot find a way to export the backward graph while using torch.export. Is there a trick to do this? Here is a minimal script:
```python
import torch
from torch._functorch.aot_autograd import aot_module

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.linear(x)
        return x

inp = torch.randn(2, 3, device="cuda", requires_grad=True)
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))

# check the exported (forward-only) graph module
print(ep.module().code)

# run it with torch.compile
compiled_module = torch.compile(ep.module(), backend="inductor")
res = compiled_module(inp)
res.sum().backward()
# how do I print the backward graph here?
```
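One approach I have tried is to skip torch.export for this and feed the module to `aot_module` with printing compilers, so both the forward and backward FX graphs are dumped as they are compiled. This is only a sketch (CPU, assuming the `functorch.compile` entry points `aot_module` and `make_boxed_func`):

```python
import torch
from functorch.compile import aot_module, make_boxed_func

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

graphs = []  # compiled fx graphs land here: forward first, backward second

def print_compiler(gm, example_inputs):
    # invoked once for the forward graph and once for the backward graph
    graphs.append(gm)
    print(gm.code)
    return make_boxed_func(gm.forward)

m = M()
inp = torch.randn(2, 3, requires_grad=True)
aot_mod = aot_module(m, fw_compiler=print_compiler, bw_compiler=print_compiler)
out = aot_mod(inp)
out.sum().backward()  # by this point both graphs have been compiled and printed
```

After `backward()` runs, `graphs[1]` holds the backward `GraphModule`, so it can be inspected or serialized rather than just printed.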
Thank you for your support!