Hi there, I would like to trace a backward graph that has multiple outputs (`stage_output`) and inputs (`input_values`), where only some of the outputs require grad (specified by `outputs_with_grads_idxs`). The function I want to trace looks like the snippet below and contains a `backward()` call.

```
from typing import List

import torch

def stage_backward(stage_output, output_grads, input_values, outputs_with_grads_idxs: List[int]):
    # some preprocessing code, which among other things selects the outputs
    # that need backward, i.e. stage_output_tensors = [stage_output[i] for i in outputs_with_grads_idxs]
    torch.autograd.backward(
        stage_output_tensors,
        grad_tensors=output_grad_tensors,
    )
```

(check full code here).
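For context, the core of `stage_backward` is just the standard `grad_tensors` form of `torch.autograd.backward`, applied to the selected subset of outputs. A minimal standalone illustration (made-up shapes, not the real pipeline code):

```
import torch

x = torch.randn(3, requires_grad=True)
y = x * 2  # pretend y and z are the stage outputs
z = x + 1
# only y is selected by outputs_with_grads_idxs, so only it goes into backward
torch.autograd.backward([y], grad_tensors=[torch.ones(3)])
print(x.grad)  # the gradients land on the stage inputs / parameters
```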

I tried to trace into the function with a stateless wrapper, something like:

```
def stateless_backward(params, buffers, activations, kwargs_for_stage_backward):
    func_out = stage_backward(**kwargs_for_stage_backward)
    # collect the gradients that backward() accumulated on the parameters
    grads = {k: v.grad for k, v in params.items()}
    return func_out, grads

gm = make_fx(stateless_backward, tracing_mode='fake')(*args)
```

where `params` and `buffers` are `FakeTensor`s saved from forward tracing (something like `make_fx(stateless_forward, tracing_mode='fake')`), and `activations` are `FakeTensor`s collected by iterating over the `_saved_xxx` attributes of `grad_fn` for every tensor in `stage_output`.
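The collection is roughly this (a simplified sketch of what I do, not the exact code):

```
import torch

def collect_saved_tensors(stage_output):
    # walk the autograd graph of each output and grab every _saved_* tensor
    saved, seen = [], set()
    stack = [t.grad_fn for t in stage_output
             if isinstance(t, torch.Tensor) and t.grad_fn is not None]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        for attr in dir(fn):
            if attr.startswith("_saved_"):
                val = getattr(fn, attr)
                if isinstance(val, torch.Tensor):
                    saved.append(val)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return saved
```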

The problem is that the resulting graph always contains `_tensor_constant` attributes, and I suspect that I missed some tensors in my `stateless_backward` function arguments, so they got cloned and baked into the traced code as constants.
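For reference, this is how I spot them in the traced module (standard FX graph inspection):

```
# list the tensors that make_fx baked in as attributes instead of taking them as inputs
for node in gm.graph.nodes:
    if node.op == "get_attr" and "_tensor_constant" in node.target:
        const = getattr(gm, node.target)
        print(node.target, tuple(const.shape), const.dtype)
```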

I’m a beginner with stateless graphs and `FakeTensor`, so I’m not even sure whether this is a not-even-wrong approach. Please share if you have any ideas :), any suggestion will be appreciated!