How does torch.compile work with autograd?

Hey @youkaichao - AOTAutograd is the major component that handles the backward when running torch.compile. It also handles other things like functionalization, tensor subclasses, tracing through other pytorch behavior implemented in the dispatcher (like functorch and AMP), and normalizing the graph from torch IR to ATen IR.

There isn’t super great AOTAutograd documentation (yet… stay tuned!)

The short answer is that AOTAutograd will set things up such that if your forward consists of 3 separate compiled graphs, then the backward will also consist of 3 backward graphs. There’s also an experimental “compiled autograd” mode that will try to generate a single backward graph, even if there are graph breaks in the forward.
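
As an aside, if you want to check how many separate forward graphs Dynamo is actually producing for your function (and where the breaks are), torch._dynamo.explain is handy. A minimal sketch, assuming a recent PyTorch 2.x where explain returns an ExplainOutput object (the API has changed across releases):

import torch

def f(x):
    tmp1 = x.sin().sin()
    print("graph break!")  # print() forces a graph break
    return tmp1.sin().sin()

x = torch.ones(2, requires_grad=True)
explanation = torch._dynamo.explain(f)(x)
print(explanation.graph_count)        # expect 2
print(explanation.graph_break_count)  # expect 1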

In more detail:

How do we deal with partial graphs in AOTAutograd? When a graph break occurs, the forward graph is broken into several subgraphs. However, we have only one backward graph with one backward() call. Do we break the backward graph according to the forward graph, too?

Let’s say that I have a forward model like this:

import torch

@torch.compile
def f(x):
    tmp1 = x.sin().sin()
    print("graph break!")  # the print() call causes a graph break
    out = tmp1.sin().sin()
    return out

x = torch.ones(2, requires_grad=True)
out = f(x)

Our forward will consist of two forward graphs, each of which is compiled by Inductor separately (you can get these yourself by running the above with TORCH_LOGS="graph_code" python tmp.py). I printed the first of the two graphs below:

def forward(self, L_x_ : torch.Tensor):
    sin = L_x_.sin()
    tmp1 = sin.sin()
    return (tmp1,)
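
Aside from TORCH_LOGS, another way to see the torch-IR graphs that Dynamo captures is to pass a custom backend to torch.compile that just prints each graph it receives. A minimal sketch (printing_backend is my own illustrative name; it returns the captured graph unchanged, so nothing actually gets compiled):

import torch

def printing_backend(gm: torch.fx.GraphModule, example_inputs):
    # Called once per captured graph; with one graph break we'll see it twice.
    print(gm.code)
    return gm.forward  # run the captured graph as-is, no real compilation

@torch.compile(backend=printing_backend)
def f(x):
    tmp1 = x.sin().sin()
    print("graph break!")
    return tmp1.sin().sin()

f(torch.ones(2, requires_grad=True))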

AOTAutograd will then take that first graph of torch ops and, ahead of time (while compiling the forward, before the backward has run), do the following:

(1) Run the above torch code with FakeTensors, tracing through the autograd engine (as well as all other PyTorch functionality implemented inside the dispatcher), to generate a corresponding backward graph. The code where AOTAutograd traces the backward lives here.

This will create a single graph containing the joint forward-backward. One thing to note: the original user function has a single input, but this joint graph has two inputs: a forward input, and a grad_output that is an input to the backward:

def forward(self, primals, tangents):
    primals_1: f32[2], tangents_1: f32[2], = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    # File: /data/users/hirsheybar/b/pytorch/tmp5.py:7, code: out = tmp1.sin().sin()
    sin: f32[2] = torch.ops.aten.sin.default(primals_1)
    sin_1: f32[2] = torch.ops.aten.sin.default(sin)
    cos: f32[2] = torch.ops.aten.cos.default(sin);  sin = None
    mul: f32[2] = torch.ops.aten.mul.Tensor(tangents_1, cos);  tangents_1 = cos = None
    cos_1: f32[2] = torch.ops.aten.cos.default(primals_1);  primals_1 = None
    mul_1: f32[2] = torch.ops.aten.mul.Tensor(mul, cos_1);  mul = cos_1 = None
    return pytree.tree_unflatten([sin_1, mul_1], self._out_spec)
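
To get an intuition for what that joint trace looks like, here is a hand-written sketch in the same spirit (this is not AOTAutograd's actual code path, just an illustration using make_fx): we trace a function that takes primals and tangents, runs the forward, and calls torch.autograd.grad so that the backward ops land in the same graph.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def joint(primals_1, tangents_1):
    # forward
    out = primals_1.sin().sin()
    # backward: ask autograd for the input gradient, given the incoming tangents
    (grad_primals_1,) = torch.autograd.grad(out, primals_1, grad_outputs=tangents_1)
    return out, grad_primals_1

primals = torch.ones(2, requires_grad=True)
tangents = torch.ones(2)
joint_gm = make_fx(joint)(primals, tangents)
print(joint_gm.code)  # one ATen graph containing both the sin ops and the cos/mul backward ops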

(2) AOTAutograd will take this joint graph and partition it into separate forward and backward graphs, which Inductor will compile separately (partitioning code lives here and here)

Example partitioned graph (from running the code snippet above):

# forward
def forward(self, primals_1: f32[2]):
    # File: /data/users/hirsheybar/b/pytorch/tmp5.py:7, code: out = tmp1.sin().sin()
    sin: f32[2] = torch.ops.aten.sin.default(primals_1)
    sin_1: f32[2] = torch.ops.aten.sin.default(sin);  sin = None
    return [sin_1, primals_1]

# backward
def forward(self, primals_1: f32[2], tangents_1: f32[2]):
    # File: /data/users/hirsheybar/b/pytorch/tmp5.py:7, code: out = tmp1.sin().sin()
    sin: f32[2] = torch.ops.aten.sin.default(primals_1)
    cos: f32[2] = torch.ops.aten.cos.default(sin);  sin = None
    mul: f32[2] = torch.ops.aten.mul.Tensor(tangents_1, cos);  tangents_1 = cos = None
    cos_1: f32[2] = torch.ops.aten.cos.default(primals_1);  primals_1 = None
    mul_1: f32[2] = torch.ops.aten.mul.Tensor(mul, cos_1);  mul = cos_1 = None
    return [mul_1]

Also, you can see that “primals_1” is one of the inputs to the backward. One thing the partitioner also handles is deciding what should be saved for the backward vs. recomputed during the backward pass. In the above example, we had to save primals_1 for the backward, since the backward formula for x.sin() requires re-using x (the derivative of x.sin() is x.cos()).
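
You can sanity-check that formula in eager mode: for out = p.sin().sin(), the chain rule gives grad = tangents * cos(sin(p)) * cos(p), which is exactly the mul / mul_1 sequence in the backward graph above, and is why p itself (primals_1) has to be available at backward time.

import torch

p = torch.ones(2, requires_grad=True)
out = p.sin().sin()
out.sum().backward()  # tangents are all ones

with torch.no_grad():
    # mirrors the backward graph: mul = tangents * cos(sin(p)); mul_1 = mul * cos(p)
    manual = torch.ones(2) * torch.cos(p.sin()) * torch.cos(p)

print(torch.allclose(p.grad, manual))  # True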

(3) AOTAutograd will wrap these two compiled functions into a torch.autograd.Function object: its forward will run the compiled forward graph, and its backward will run the compiled backward graph (code here). This is how torch.compile effectively handles “partial graphs” that interoperate with the rest of eager-mode pytorch: we have a bunch of compiled autograd.Function objects that are stitched together with the eager-mode code.
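
To make that concrete, here is a rough sketch of the shape of what AOTAutograd generates for the second subgraph above. This is not its actual implementation: compiled_fw and compiled_bw are hypothetical stand-ins for the Inductor-compiled forward/backward graphs from step (2).

import torch

# Hypothetical stand-ins for the compiled forward/backward graphs shown above.
def compiled_fw(primals_1):
    sin_1 = torch.ops.aten.sin.default(torch.ops.aten.sin.default(primals_1))
    return sin_1, primals_1  # the output, plus the tensor saved for backward

def compiled_bw(primals_1, tangents_1):
    cos = torch.ops.aten.cos.default(torch.ops.aten.sin.default(primals_1))
    mul = torch.ops.aten.mul.Tensor(tangents_1, cos)
    cos_1 = torch.ops.aten.cos.default(primals_1)
    return torch.ops.aten.mul.Tensor(mul, cos_1)

class CompiledFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, primals_1):
        out, saved = compiled_fw(primals_1)
        ctx.save_for_backward(saved)
        return out

    @staticmethod
    def backward(ctx, tangents_1):
        (saved,) = ctx.saved_tensors
        return compiled_bw(saved, tangents_1)

x = torch.ones(2, requires_grad=True)
tmp1 = x.sin().sin()                # imagine this is the first compiled region
out = CompiledFunction.apply(tmp1)  # the second compiled region, wrapped in an autograd.Function
out.sum().backward()                # eager autograd stitches the two regions together
print(x.grad)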

You can see the generated joint graph and the partitioned forward and backward graphs by printing the logs with TORCH_LOGS="graph_code,aot" python tmp5.py (this will print both the torch IR graph, and the ATen IR graphs after AOTAutograd has run).
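
If you'd rather poke at this programmatically than through logs, the functorch.compile frontend to AOTAutograd lets you pass your own forward/backward “compilers” that simply print the partitioned graphs. A minimal sketch, assuming the (experimental) functorch.compile API and the same min-cut partitioner that torch.compile uses by default:

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition

def print_graph(gm, example_inputs):
    print(gm.code)
    return gm  # hand the graph module back; AOTAutograd will call it directly

def g(x):
    return x.sin().sin()

compiled_g = aot_function(
    g,
    fw_compiler=print_graph,
    bw_compiler=print_graph,
    partition_fn=min_cut_rematerialization_partition,
)

x = torch.ones(2, requires_grad=True)
out = compiled_g(x)
out.sum().backward()
# print_graph is invoked on the partitioned forward and backward ATen graphs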
