How to trace torch.autograd.backward or torch.autograd.grad?

Hi there, I would like to trace a backward function that has multiple outputs (stage_output) and inputs (input_values), where only some of the outputs require grad (selected by outputs_with_grads_idxs). The function looks like the snippet below and contains a torch.autograd.backward() call.

from typing import List
import torch

def stage_backward(stage_output, output_grads, input_values, outputs_with_grads_idxs: List[int]):
    # some preprocessing: keep only the outputs that need backward and their grad tensors
    stage_output_tensors = [stage_output[i] for i in outputs_with_grads_idxs]
    output_grad_tensors = [output_grads[i] for i in outputs_with_grads_idxs]
    torch.autograd.backward(stage_output_tensors, grad_tensors=output_grad_tensors)

(check full code here).

I tried to trace it through a wrapper function like

from torch.fx.experimental.proxy_tensor import make_fx

def stateless_backward(params, buffers, activations, kwargs_for_stage_backward):
    func_out = stage_backward(**kwargs_for_stage_backward)
    # collect the gradients produced on the (fake) parameters during backward
    grads = {k: v.grad for k, v in params.items()}
    return func_out, grads

gm = make_fx(stateless_backward, tracing_mode='fake')(*args)

where params and buffers are FakeTensors saved from forward tracing (something like make_fx(stateless_forward, tracing_mode='fake')), and activations are FakeTensors collected by iterating over the _saved_xxx attributes of grad_fn for every tensor in stage_output.
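
For reference, this is roughly how I collect the activations (a simplified sketch, not my exact code; collect_saved_tensors is just a name for this sketch, the _saved_xxx attribute names differ per grad_fn type, and this assumes the saved tensors have not been freed yet):

import torch

def collect_saved_tensors(stage_output):
    # Walk the autograd graph starting from the stage outputs and gather
    # every tensor stored in a _saved_* slot of the grad_fn nodes.
    saved, seen = [], set()
    stack = [t.grad_fn for t in stage_output
             if isinstance(t, torch.Tensor) and t.grad_fn is not None]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        for attr in dir(node):
            if attr.startswith("_saved_"):
                value = getattr(node, attr)
                if isinstance(value, torch.Tensor):
                    saved.append(value)
        # keep walking towards the graph inputs
        stack.extend(next_fn for next_fn, _ in node.next_functions)
    return saved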

The problem is that the resulting graph always contains _tensor_constant nodes, and I suspect I missed some tensors in the stateless_backward arguments, so they got cloned and baked into the traced code as constants.

I’m a beginner with stateless graphs and FakeTensor, so I’m not even sure whether this approach makes sense at all. Please share if you have any ideas :) Any suggestion would be appreciated!

aot_export_module should give you some help; otherwise you may need to use __torch_dispatch__ to achieve your requirements.
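
Something along these lines, for example (just a sketch; aot_export_module lives under torch._functorch and its signature may change between releases, and mod / example_args stand for your module and example inputs):

from torch._functorch.aot_autograd import aot_export_module

# trace_joint=True exports a single joint forward+backward graph;
# output_loss_index says which forward output the backward starts from.
gm, graph_signature = aot_export_module(
    mod, example_args, trace_joint=True, output_loss_index=0
)
print(gm.graph)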

These discussions might provide some help: How does torch.compile work with autograd? - #4 by Chillee

Thanks shuokay! The default_partition func is just what I need.
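
For anyone who lands here later, this is roughly what I ended up with (a sketch based on functorch.compile; fn stands for whatever function is being traced, and the bw_compiler only ever sees the backward half produced by the partitioner):

from functorch.compile import aot_function, default_partition

def make_printer(name):
    def compiler(fx_module, example_inputs):
        # fx_module is just one half of the joint graph after partitioning
        print(f"=== {name} graph ===")
        print(fx_module.graph)
        return fx_module  # return it unchanged so it still runs
    return compiler

traced = aot_function(
    fn,
    fw_compiler=make_printer("forward"),
    bw_compiler=make_printer("backward"),
    partition_fn=default_partition,
)
# The forward graph is printed on the first call to traced(...),
# the backward graph on the first .backward() through its outputs.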

Btw, do you have any idea how to do this if I only want to trace the backward graph (instead of tracing a joint one and partitioning it)?