Hi there, I would like to trace the backward graph, which has multiple outputs (stage_output) and inputs (input_values); some of the outputs require grad (specified by outputs_with_grads_idxs). The function looks like the snippet below and contains a backward() operation:
from typing import List

import torch


def stage_backward(stage_output, output_grads, input_values, outputs_with_grads_idxs: List[int]):
    # some preprocessing code
    torch.autograd.backward(
        stage_output_tensors,  # outputs that need backward, i.e. stage_output_tensors = stage_output[outputs_with_grads_idxs]
        grad_tensors=output_grad_tensors,
    )
(check full code here).
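For context, the elided preprocessing roughly filters the outputs (and their matching grads) that actually need backward; this is a simplified paraphrase of it, not the exact code:

# Simplified paraphrase of the "# some preprocessing code" part above:
# keep only the outputs (and matching output grads) listed in outputs_with_grads_idxs.
stage_output_tensors = [stage_output[i] for i in outputs_with_grads_idxs]
output_grad_tensors = [output_grads[i] for i in outputs_with_grads_idxs]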
I tried to trace it through a stateless wrapper function like this:

from torch.fx.experimental.proxy_tensor import make_fx


def stateless_backward(params, buffers, activations, kwargs_for_stage_backward):
    func_out = stage_backward(**kwargs_for_stage_backward)
    grads = {k: v.grad for k, v in params.items()}
    return func_out, grads


gm = make_fx(stateless_backward, tracing_mode='fake')(*args)
where params and buffers are FakeTensors saved from forward tracing (something like make_fx(stateless_forward, tracing_mode='fake')), and activations are FakeTensors collected by iterating the _saved_xxx attributes of the grad_fn of each tensor in stage_output.
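For reference, the way I collect those saved activations is roughly the following (a simplified sketch; collect_saved_tensors is just my helper name, and the real code has a few more checks):

import torch


def collect_saved_tensors(outputs):
    # Walk the autograd graph of `outputs` and gather every tensor a grad_fn
    # saved for backward (exposed as _saved_xxx attributes on the node).
    saved, seen = [], set()
    stack = [t.grad_fn for t in outputs if isinstance(t, torch.Tensor) and t.grad_fn is not None]
    while stack:
        fn = stack.pop()
        if fn is None or id(fn) in seen:
            continue
        seen.add(id(fn))
        for name in dir(fn):
            if name.startswith('_saved_'):
                val = getattr(fn, name)
                if isinstance(val, torch.Tensor):
                    saved.append(val)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return saved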
The problem is that the resulting graph always contains _tensor_constant attributes, and I suspect that I missed some tensors in my stateless_backward function arguments, so they got cloned and baked into the traced code as constants.
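In case it is relevant, this is how I inspect what got baked in (a small debugging snippet; it assumes the constants live as top-level attributes of gm):

for node in gm.graph.nodes:
    # get_attr nodes named _tensor_constant* are tensors that make_fx embedded into the module
    if node.op == 'get_attr' and '_tensor_constant' in str(node.target):
        const = getattr(gm, str(node.target))
        print(node.target, tuple(const.shape), const.dtype)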
I’m a beginner with stateless graphs and FakeTensor, so I’m not even sure whether this approach makes sense at all. Please share if you have any ideas :) Any suggestion will be appreciated!