Haha, AOTAutograd used to be a lot simpler. And conceptually, it is not so complicated.
Basically, the pseudocode for how it works is:
Say we have our forwards:

```python
def f(*inputs):
    return outputs
```
And say that we can call `trace(f)(*inputs)` to get our forwards graph.
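In current PyTorch, `make_fx` (from `torch.fx.experimental.proxy_tensor`) plays the role of `trace` here. A minimal runnable sketch, with a toy `f` of my own choosing:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Toy forwards; stands in for the user's f(*inputs).
def f(x):
    return x.sin()

# make_fx runs f on the inputs and records every ATen op it
# dispatches into an FX graph.
fw_graph = make_fx(f)(torch.randn(3))
print(fw_graph.code)  # the captured graph: a single aten.sin call
```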
In order to get the backwards pass, we trace something like:

```python
def joint_fw_bw(fw_inputs, grad_outs):
    fw_out = f(*fw_inputs)
    grad_inps = torch.autograd.grad(fw_out, fw_inputs, grad_outputs=grad_outs)
    return fw_out, grad_inps
```
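This joint trace can be done with `make_fx` too, since tracing happens below autograd at the dispatcher level, so the backwards ops get recorded along with the forwards ops. A runnable sketch reusing the toy `f` (again my example, not the real AOTAutograd entry point):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x.sin()

def joint_fw_bw(fw_inputs, grad_outs):
    fw_out = f(*fw_inputs)
    grad_inps = torch.autograd.grad(fw_out, fw_inputs, grad_outputs=grad_outs)
    return fw_out, grad_inps

x = torch.randn(3, requires_grad=True)
gO = torch.randn(3)
joint = make_fx(joint_fw_bw)([x], [gO])
print(joint.code)  # contains both sin (forwards) and cos * grad (backwards)
```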
Then, we simply partition this graph into two to give us the forwards pass and the backwards pass.
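In the real thing, a partitioner (e.g. `default_partition` or `min_cut_rematerialization_partition` in functorch) makes this split, deciding which intermediates the forwards should save for the backwards. Hand-partitioning the `sin` example above shows the shape of the result; saving `x` itself is my choice for illustration, since the partitioner makes that call automatically, trading recomputation against memory:

```python
import torch

def fw(x):
    # Forwards: return the output *plus* the values saved for backwards.
    return x.sin(), x

def bw(saved_x, grad_out):
    # Backwards: d(sin x)/dx = cos x, chained with the incoming gradient.
    return saved_x.cos() * grad_out

# Sanity check against eager autograd:
x = torch.randn(3, requires_grad=True)
out, saved = fw(x)
out.backward(torch.ones_like(out))
assert torch.allclose(bw(saved.detach(), torch.ones(3)), x.grad)
```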