This post may be helpful: How to set wrap function using TorchDynamo graph capture?
Another way I know of is to use aot_module with a custom compiler that records each graph it receives:
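import torch
from functorch.compile import aot_module, make_boxed_func

captured_graphs = []

def custom_compiler(m: torch.fx.GraphModule, _):
    # Record the traced graph, then hand it back unchanged as a boxed callable.
    captured_graphs.append(m)
    return make_boxed_func(m.forward)

# fw_compiler is also used for the backward graph when bw_compiler is not given.
aot_model = aot_module(model, fw_compiler=custom_compiler)
y = aot_model(inputs)
y.sum().backward()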
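
To try this end to end, here is a minimal self-contained sketch. The toy model (nn.Linear(4, 2)) and the random input batch are my own placeholders, not something from your setup, so swap in your real model and inputs. After the backward pass it prints the generated code of every graph that was captured.

```python
import torch
import torch.nn as nn
from functorch.compile import aot_module, make_boxed_func

captured_graphs = []

def custom_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Keep a reference to the traced graph, then run it as-is.
    captured_graphs.append(gm)
    return make_boxed_func(gm.forward)

model = nn.Linear(4, 2)      # hypothetical toy model, for illustration only
inputs = torch.randn(3, 4)   # hypothetical input batch

aot_model = aot_module(model, fw_compiler=custom_compiler)
y = aot_model(inputs)
y.sum().backward()           # run backward so the backward graph is traced as well

# Inspect what was captured.
for i, gm in enumerate(captured_graphs):
    print(f"--- captured graph {i} ---")
    print(gm.code)
```

If you want to tell the forward and backward graphs apart, one option is to pass a separate function as bw_compiler and append to a second list there.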