How does torch.compile work with autograd?

I used to think that graph capture and graph compilation could be totally separated, and that I could learn Dynamo and Inductor separately. That is true for the forward computation, but things seem to become much more complicated when autograd comes into play.

  1. How do we deal with partial graphs in AOT Autograd?

When a graph break occurs, the forward graph is broken into several subgraphs. However, we have only one backward graph, produced by a single backward() call. Do we break the backward graph according to the forward graph, too?

  2. How does Dynamo deal with the backward() call?

When the Python code calls backward(), a lot happens under the hood. If the transformed bytecode also calls backward(), then it just invokes the autograd engine in eager mode. If Dynamo instead transforms the backward() call into a manual invocation of a backward graph, I suppose that would be quite a big function (with all the parameters as inputs, and it would be troublesome to collect the necessary parameters). And we would have to deal with backward()'s arguments, i.e. backward(gradient=None, retain_graph=None, create_graph=False, inputs=None).

  3. Is there any documentation about AOT Autograd?

After searching for a while, I have only found some scattered documentation in functorch and about FakeTensor.

Might be relevant to @jansel @Chillee


Hey @youkaichao - AOTAutograd is the major component that handles the backward when running torch.compile. It also handles other things like functionalization, tensor subclasses, tracing through other pytorch behavior implemented in the dispatcher (like functorch and AMP), and normalizing the graph from torch IR to ATen IR.

There isn’t super great AOTAutograd documentation (yet… stay tuned!)

The short answer is that AOTAutograd will set things up such that if your forward consists of 3 separate compiled graphs, then the backward will also consist of 3 backward graphs. There’s also an experimental “compiled autograd” mode that will try to generate a single backward graph, even if there are graph breaks in the forward.

In more detail:

How do we deal with partial graphs in AOT Autograd? When a graph break occurs, the forward graph is broken into several subgraphs. However, we have only one backward graph, produced by a single backward() call. Do we break the backward graph according to the forward graph, too?

Let’s say that I have a forward model like this:

import torch

@torch.compile
def f(x):
    tmp1 = x.sin().sin()
    print("graph break!")
    out = tmp1.sin().sin()
    return out

x = torch.ones(2, requires_grad=True)
out = f(x)

Our forward will consist of two forward graphs, each compiled by Inductor separately (you can get these yourself by running the above with TORCH_LOGS="graph_code" python tmp.py). I've printed the first of the two graphs below:

def forward(self, L_x_ : torch.Tensor):
    sin = L_x_.sin()
    tmp1 = sin.sin()
    return (tmp1,)

AOTAutograd will then take that first graph of torch ops, and ahead-of-time (while compiling the forward, before the backward has run), it will do the following:

(1) Run the above torch code with FakeTensors, tracing through the autograd engine (as well as all other pytorch functionality implemented inside of the dispatcher), to generate a corresponding backward graph. The code where AOTAutograd traces the backward lives here.

This will create a single graph containing the joint forward-backward. One thing to note: the original user function has a single input, but this joint graph has two inputs: a forward input, and a grad_output that is an input to the backward:

def forward(self, primals, tangents):
    primals_1: f32[2], tangents_1: f32[2], = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    # File: /data/users/hirsheybar/b/pytorch/tmp5.py:7, code: out = tmp1.sin().sin()
    sin: f32[2] = torch.ops.aten.sin.default(primals_1)
    sin_1: f32[2] = torch.ops.aten.sin.default(sin)
    cos: f32[2] = torch.ops.aten.cos.default(sin);  sin = None
    mul: f32[2] = torch.ops.aten.mul.Tensor(tangents_1, cos);  tangents_1 = cos = None
    cos_1: f32[2] = torch.ops.aten.cos.default(primals_1);  primals_1 = None
    mul_1: f32[2] = torch.ops.aten.mul.Tensor(mul, cos_1);  mul = cos_1 = None
    return pytree.tree_unflatten([sin_1, mul_1], self._out_spec)

(2) AOTAutograd will take this joint graph and partition it into a separate forward and backward graph, which Inductor will compile separately (the partitioning code lives here and here).

Example partitioned graph (from running the code snippet above):

# forward
def forward(self, primals_1: f32[2]):
    # File: /data/users/hirsheybar/b/pytorch/tmp5.py:7, code: out = tmp1.sin().sin()
    sin: f32[2] = torch.ops.aten.sin.default(primals_1)
    sin_1: f32[2] = torch.ops.aten.sin.default(sin);  sin = None
    return [sin_1, primals_1]

# backward
def forward(self, primals_1: f32[2], tangents_1: f32[2]):
    # File: /data/users/hirsheybar/b/pytorch/tmp5.py:7, code: out = tmp1.sin().sin()
    sin: f32[2] = torch.ops.aten.sin.default(primals_1)
    cos: f32[2] = torch.ops.aten.cos.default(sin);  sin = None
    mul: f32[2] = torch.ops.aten.mul.Tensor(tangents_1, cos);  tangents_1 = cos = None
    cos_1: f32[2] = torch.ops.aten.cos.default(primals_1);  primals_1 = None
    mul_1: f32[2] = torch.ops.aten.mul.Tensor(mul, cos_1);  mul = cos_1 = None
    return [mul_1]

Also, you can see that “primals_1” is one of the inputs to the backward. Another thing the partitioning handles is deciding what should be saved for the backward vs. recomputed during the backward pass. In the above example, we had to save primals_1 for the backward, since the backward formula for x.sin() requires re-using x (the derivative of x.sin() is x.cos()).

(3) AOTAutograd will wrap these two compiled functions into a torch.autograd.Function object: its forward will run the compiled forward graph, and its backward will run the compiled backward graph (code here). This is how torch.compile effectively handles “partial graphs” that inter-operate with the rest of eager-mode pytorch: we have a bunch of compiled autograd.Function objects that are stitched together with the eager-mode code.
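To make that concrete, here is a heavily simplified sketch of the wrapping for the sin().sin() example above. The compiled_fw/compiled_bw functions are hand-written stand-ins for what Inductor would hand back (the real AOTAutograd wrapper handles many more cases, e.g. input mutations and multiple outputs):

import torch

# Stand-ins for the compiled artifacts: the "forward" returns the user output
# plus the activation the partitioner decided to save (here, the input itself),
# and the "backward" consumes that activation plus the incoming gradient.
def compiled_fw(primals_1):
    sin = torch.ops.aten.sin.default(primals_1)
    sin_1 = torch.ops.aten.sin.default(sin)
    return sin_1, primals_1

def compiled_bw(primals_1, tangents_1):
    sin = torch.ops.aten.sin.default(primals_1)
    cos = torch.ops.aten.cos.default(sin)
    mul = torch.ops.aten.mul.Tensor(tangents_1, cos)
    cos_1 = torch.ops.aten.cos.default(primals_1)
    return torch.ops.aten.mul.Tensor(mul, cos_1)

class CompiledFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, primals_1):
        out, saved = compiled_fw(primals_1)
        ctx.save_for_backward(saved)   # activation chosen by the partitioner
        return out

    @staticmethod
    def backward(ctx, tangents_1):
        (saved,) = ctx.saved_tensors
        return compiled_bw(saved, tangents_1)

x = torch.ones(2, requires_grad=True)
CompiledFunction.apply(x).sum().backward()
print(x.grad)   # same gradient as running x.sin().sin().sum().backward() eagerly

With a graph break, you end up with one such autograd.Function per compiled subgraph, and the eager autograd engine stitches their backwards together.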

You can see the generated joint graph and the partitioned forward and backward graphs by printing the logs with TORCH_LOGS="graph_code,aot" python tmp5.py (this will print both the torch IR graph, and the ATen IR graphs after AOTAutograd has run).


That’s a fantastic reply!

My take-home message is that the forward graph is converted and lowered to ATen ops, producing both a forward and a backward graph (in the form of a new autograd.Function).

Mentally it is clear, though it looks intimidating to scan aot_autograd.py with its over 4K lines of code :frowning:


Haha, AOTAutograd used to be a lot simpler. And conceptually, it is not so complicated.

Basically, the pseudocode for how it works is:

Say we have our forwards

def f(*inputs):
    return outputs

And say that we can call trace(f)(*inputs) to get our forwards graph.

In order to get the backwards pass, we trace something like

def joint_fw_bw(fw_inputs, grad_outs):
    fw_out = f(*fw_inputs)
    grad_inps = torch.autograd.grad(fw_out, inputs=fw_inputs, grad_outputs=grad_outs)
    return fw_out, grad_inps

Then, we simply partition this graph into two to give us the forwards pass and the backwards pass.
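Here is a rough, runnable version of that pseudocode for a tiny function, using make_fx as the trace() (a sketch to illustrate the mental model, assuming make_fx can trace through the autograd.grad call, which is essentially what the min-cut post linked later in the thread does; AOTAutograd itself does the equivalent tracing with FakeTensors plus a lot of extra bookkeeping, and joint_fw_bw is my own helper name):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x.sin().sin()

def joint_fw_bw(x, grad_out):
    out = f(x)
    # the backward formulas get traced while torch.autograd.grad executes
    grad_x, = torch.autograd.grad(out, inputs=x, grad_outputs=grad_out)
    return out, grad_x

x = torch.ones(2, requires_grad=True)
grad_out = torch.ones(2)
joint_graph = make_fx(joint_fw_bw)(x, grad_out)
print(joint_graph.code)   # contains both the forward sin ops and the cos/mul backward ops

The partitioner then splits this single joint graph into the separate forward and backward graphs shown earlier in the thread.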


This is a very neat mental model! The only remaining question is that this produces a fixed computation graph, and I don’t know how AOT Autograd can improve the computation.

The only remaining question is that this produces a fixed computation graph, and I don’t know how AOT Autograd can improve the computation.

What do you mean by this?

I mean: how do we determine which intermediate variables to save for the most efficient backward? I’m still trying to understand the post Min-cut optimal(*) recomputation (i.e. activation checkpointing) with AOTAutograd - #9 by Chillee.

Oh, this is not an issue, since we trace joint_fw_bw together, and then partition.

So, during partitioning we have full freedom to modify what gets saved in the forwards pass and what gets recomputed in the backwards pass.
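For example, you can hand AOTAutograd a different partition_fn and look at what the compiled forward returns, i.e. what gets saved. A sketch, assuming the functorch.compile exports below (the graph-printing "compiler" is my own and just returns the graph module unchanged):

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition

def f(x):
    return x.sin().sin()

def print_graph(name):
    def compiler(fx_module, example_inputs):
        print(f"=== {name} ===")
        print(fx_module.code)   # in the forward, the extra return values are what gets saved
        return fx_module        # a GraphModule is callable, so this works as a no-op "compiler"
    return compiler

compiled_f = aot_function(
    f,
    fw_compiler=print_graph("forward"),
    bw_compiler=print_graph("backward"),
    partition_fn=min_cut_rematerialization_partition,
)

x = torch.ones(4, requires_grad=True)
compiled_f(x).sum().backward()

Comparing this against the default partitioner shows how much freedom the partitioning step has over what is saved vs. recomputed.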

Thanks for the super fast reply. I’ll try to understand the joint fw and bw graph first; I’m having difficulty understanding it.

Slides 7-16 here might help: AOTAutograd - Google Slides


Thanks for the pointer, it really helps!

Plus: the example of sin in the slides seems to have no memory benefit, I think. Another example you provide in Min-cut optimal(*) recomputation (i.e. activation checkpointing) with AOTAutograd did have a memory benefit.

Plus: I’m confused by the word “partition”. By partition, we usually mean to partition the graph into disjoint subsets. In the following example, if we save add_2 for backward, both fwd and bwd have to compute the edge cos on add_2. When we partition the joint graph to get the bwd graph based on activations {x1, ..., xn}, the fwd graph should be the minimum graph that can compute {x1, ..., xn, output}. Those two graphs might have some overlap.

Final Plus: I used to think AOT Autograd could discover something like an optimized sigmoid (e.g. users write the eager code z = 1 / (1 + torch.exp(-x)), and we figure out the smart backward z * (1 - z)). Now I understand what AOT Autograd can do.

the example of sin in the slides seems to have no memory benefit, I think

Well, the example in the slides was just about explaining what the joint graph looks like - not about what an optimized version looks like.

I’m confused by the word “partition”. By partition, we usually mean to partition the graph into disjoint subsets.

Haha, this is true. In this case, the min-cut “partitioner” is not really partitioning the graph into two disjoint subsets - the backwards pass will be recomputing significant parts of the forwards pass. I still think of it as a “partitioning” problem because we’re given a graph with signature joint(fw_inputs, bw_inputs) => (fw_outputs, bw_outputs), and we need to return two graphs forward(fw_inputs) => (fw_outputs, activations) and backward(activations, bw_inputs) => bw_outputs.

I used to think AOT Autograd could discover something like an optimized sigmoid (e.g. users write the eager code z = 1 / (1 + torch.exp(-x)), and we figure out the smart backward z * (1 - z)). Now I understand what AOT Autograd can do.

It’s possible we could make these kinds of decisions automatically in AOTAutograd (we have all the information), but I’m actually not totally sure this is even the right thing to do :stuck_out_tongue: In this case, we would just recompute sigmoid forwards in the backwards pass, which I think will be just as efficient.


I agree with you on the sigmoid point. A pure machine learning researcher might be excited to discover the gradient formula z * (1 - z), but might not realize that the bottleneck is memory access :smile:


@Chillee I’m trying to understand what AOT Autograd does to the computation 1.0 / (exp(-x) + 5), and I have drawn the joint graph by hand. In that diagram, the red circles mark what the eager-mode autograd engine would save.

I expected that AOT Autograd could do some optimization, e.g. only save x2 for the backward and recompute x4 during the backward, so that the memory cost is reduced. However, after running AOT Autograd, I find that both x2 and x4 are saved for the backward.
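Here is roughly how I am checking what gets saved, with the same graph-printing trick as above (a sketch; the printed node names are aten ops rather than my x2/x4 labels, but the forward graph’s return list shows the saved tensors):

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition

def f(x):
    return 1.0 / (torch.exp(-x) + 5)

def print_fw(fx_module, example_inputs):
    print(fx_module.code)   # everything returned beyond the user output is saved for backward
    return fx_module

compiled_f = aot_function(
    f,
    fw_compiler=print_fw,
    bw_compiler=lambda fx_module, example_inputs: fx_module,
    partition_fn=min_cut_rematerialization_partition,
)

x = torch.randn(8, requires_grad=True)
compiled_f(x).sum().backward()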

Is this expected? How can I save only x2 if I’m striving for memory efficiency?
