How does torch.compile work with autograd?

Haha, AOTAutograd used to be a lot simpler. And conceptually, it is not so complicated.

Basically, the pseudocode for how it works is:

Say we have our forward function:

def f(*inputs):
    return outputs

And say that we can call trace(f)(*inputs) to get our forwards graph.

To get the backwards pass as well, we trace a joint function, something like:

def joint_fw_bw(fw_inputs, grad_outs):
    fw_out = f(*fw_inputs)
    # torch.autograd.grad(outputs, inputs, grad_outputs=...) returns one gradient per input
    grad_inps = torch.autograd.grad(fw_out, fw_inputs, grad_outputs=grad_outs)
    return fw_out, grad_inps

Then, we simply partition this graph into two to give us the forwards pass and the backwards pass.
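
To make the joint-tracing step concrete, here is a minimal sketch. It assumes `make_fx` from `torch.fx.experimental.proxy_tensor` as a stand-in for the `trace` above. The toy function `f`, the tensor shapes, and the flattened argument list are illustrative choices, not what AOTAutograd does internally. The partitioning step is not shown.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, w):
    # Toy forward function, chosen only for illustration.
    return torch.sin(x @ w)

def joint_fw_bw(x, w, grad_out):
    # Run the forward, then ask autograd for the input gradients,
    # all inside one function so both halves land in a single trace.
    fw_out = f(x, w)
    grad_x, grad_w = torch.autograd.grad(fw_out, (x, w), grad_outputs=grad_out)
    return fw_out, grad_x, grad_w

# Inputs must require grad so that autograd records the forward ops.
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
grad_out = torch.randn(4, 2)

# The resulting graph contains both the forward ops and the backward ops.
joint_graph = make_fx(joint_fw_bw)(x, w, grad_out)
print(joint_graph.graph)
```

The printed graph should list the forward ops followed by the ops autograd ran to compute the gradients. A partitioner would then decide which of those nodes go into the forwards graph, which go into the backwards graph, and which intermediates get saved between the two.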