How to get the backward graph while using torch.export?

This post may be helpful: How to set wrap function using TorchDynamo graph capture?

Another way I know of is to wrap the model with aot_module and pass a custom compiler that records every graph it receives:

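import torch
from functorch.compile import aot_module, make_boxed_func

captured_graphs = []

def custom_compiler(m: torch.fx.GraphModule, _):
    # Record every graph that AOT Autograd hands to the compiler.
    captured_graphs.append(m)
    return make_boxed_func(m.forward)

# bw_compiler defaults to fw_compiler, so the backward graph is recorded too.
aot_model = aot_module(model, fw_compiler=custom_compiler)
y = aot_model(inputs)
y.sum().backward()  # run backward so the backward graph has been compiled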
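
If you want the forward and backward graphs kept apart, one option is to pass a separate bw_compiler. Below is a minimal, self-contained sketch of that idea. The toy torch.nn.Linear model, the random input, and the names make_recorder, forward_graphs and backward_graphs are all made up for illustration, so swap in your own model and inputs.

```python
import torch
from functorch.compile import aot_module, make_boxed_func

forward_graphs = []
backward_graphs = []

def make_recorder(store):
    # Hypothetical helper: builds a compiler that saves the FX graph into
    # `store` and then runs it unchanged.
    def compiler(gm: torch.fx.GraphModule, example_inputs):
        store.append(gm)
        return make_boxed_func(gm.forward)
    return compiler

# Toy model and input, assumed here only to make the sketch runnable.
model = torch.nn.Linear(4, 2)
inputs = torch.randn(3, 4, requires_grad=True)

aot_model = aot_module(
    model,
    fw_compiler=make_recorder(forward_graphs),
    bw_compiler=make_recorder(backward_graphs),
)
y = aot_model(inputs)
y.sum().backward()  # after this call the backward graph has been recorded

# Inspect the generated Python code of the backward graph.
print(backward_graphs[0].code)
```

After the backward call, backward_graphs[0] is a torch.fx.GraphModule, so you can print its code, walk its graph.nodes, or pass it on to other FX tooling.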