Hi there, I would like to trace a backward graph that has multiple outputs (stage_output) and inputs (input_values), where only some of the outputs require grad (specified by outputs_with_grads_idxs). The function looks like below and contains a backward() operation:
def stage_backward(stage_output, output_grads, input_values, outputs_with_grads_idxs: List[int]):
    # some preprocessing code
    torch.autograd.backward(
        stage_output_tensors,  # outputs that need backward, i.e. stage_output_tensors = stage_output[outputs_with_grads_idxs]
        grad_tensors=output_grad_tensors
    )
(check the full code here).
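For context, the preprocessing before the backward call is essentially just the index selection from the comment above; a simplified sketch of what I mean (the real code also handles non-tensor outputs, and stage_backward_sketch is just an illustrative name):

from typing import List
import torch

def stage_backward_sketch(stage_output, output_grads, input_values,
                          outputs_with_grads_idxs: List[int]):
    # keep only the outputs that should participate in backward, and their grads
    stage_output_tensors = [stage_output[i] for i in outputs_with_grads_idxs]
    output_grad_tensors = [output_grads[i] for i in outputs_with_grads_idxs]
    torch.autograd.backward(stage_output_tensors, grad_tensors=output_grad_tensors)
    # grads w.r.t. the stage inputs are then read back from .grad
    return tuple(v.grad if isinstance(v, torch.Tensor) else None for v in input_values)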
I tried to trace it with a stateless wrapper function like this:
def stateless_backward(params, buffers, activations, kwargs_for_stage_backward):
    func_out = stage_backward(**kwargs_for_stage_backward)
    grads = {k: v.grad for k, v in params.items()}
    return func_out, grads

gm = make_fx(stateless_backward, tracing_mode='fake')(*args)
where params and buffers are FakeTensors saved from forward tracing (something like make_fx(stateless_forward, tracing_mode='fake')), and activations are FakeTensors collected by iterating over the _saved_xxx attributes of the grad_fn nodes reachable from each tensor in stage_output.
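Concretely, the activation-collection part looks roughly like this (collect_saved_tensors is just an illustrative name; simplified, without guarding against saved tensors that have already been freed):

import torch

def collect_saved_tensors(stage_output):
    # walk the autograd graph of every output and gather the tensors that
    # grad_fn nodes expose as _saved_* attributes (my "activations")
    saved, seen = {}, set()

    def visit(fn):
        if fn is None or fn in seen:
            return
        seen.add(fn)
        for name in dir(fn):
            if name.startswith("_saved_"):
                val = getattr(fn, name)
                if isinstance(val, torch.Tensor):
                    saved[f"{type(fn).__name__}.{name}"] = val
        for next_fn, _ in fn.next_functions:
            visit(next_fn)

    for out in stage_output:
        if isinstance(out, torch.Tensor):
            visit(out.grad_fn)
    return saved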
The problem is that the resulting graph always contains _tensor_constant nodes, and I suspect I missed some tensors among the stateless_backward arguments, so they got cloned and baked into the traced graph as constants.
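To see which tensors end up as constants, I dump them like this (a rough check, assuming make_fx registers them directly on the root GraphModule as _tensor_constantN attributes):

# gm is the GraphModule returned by make_fx above
for node in gm.graph.nodes:
    if node.op == "get_attr" and "tensor_constant" in str(node.target):
        const = getattr(gm, node.target)
        print(node.target, tuple(const.shape), const.dtype)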
I’m a beginner with stateless graphs and FakeTensor, so I’m not even sure whether this approach makes sense or is not even wrong. Please share if you have any ideas :) Any suggestion will be appreciated!