This half, thanks to contributions from a lot of people, autograd has seen numerous enhancements in terms of extensibility, flexibility, and debuggability.
In this post, we highlight a few features in more depth: gradient edges, post accumulate grad hooks, foreach forward and backward AD support, and logging for backward execution. We encourage checking the release notes for a more exhaustive list of changes.
Compute gradient wrt gradient edges - @albanD available in versions >=2.2
- torch.autograd.graph.get_gradient_edge is an advanced API that allows one to compute gradients wrt a tensor without necessarily keeping the tensor in memory.
- Calling get_gradient_edge returns a “pointer” to a particular version of a tensor, which you can then pass to the inputs= argument of .grad just like you would a tensor, e.g. torch.autograd.grad(y, inputs=(x,)) is equivalent to torch.autograd.grad(y, inputs=(get_gradient_edge(x),)) (see the short sketch at the end of this section).
- This is useful if you have a tensor that you plan to modify in-place later, but you want to compute gradients of the output with respect to the version of the tensor before it was modified. By using get_gradient_edge(x), you can stash a pointer to that version of the tensor before it is modified. This lets you indicate your desired inputs to autograd APIs without having to save concrete tensors.
- Future work: can we also pass a GradientEdge object as the output to .grad?
Example:
# y = exp(sin(x)); we want the derivative of y wrt sin(x)
x = torch.tensor(1., requires_grad=True).sin()
# stash a pointer to this version of x before the in-place op below
x_edge = torch.autograd.graph.get_gradient_edge(x)
x.exp_()  # x now holds exp(sin(1.)), but x_edge still refers to the sin(1.) version
torch.autograd.grad(x, inputs=(x_edge,))
Output (with TORCH_LOGS="+autograd"):
[2024-01-05 16:24:10,076] torch.autograd.graph: [DEBUG] Executing: <ExpBackward0 object at 0x147899670> with grad_outputs: [f32[]]
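To illustrate the equivalence mentioned in the bullets above, here is a minimal sketch (the variable names are our own); both calls are expected to return the same gradient:

import torch
from torch.autograd.graph import get_gradient_edge

x = torch.tensor(2.0, requires_grad=True)
z = x * 3            # non-leaf intermediate tensor
y = z ** 2

# passing the tensor directly...
(g_tensor,) = torch.autograd.grad(y, inputs=(z,), retain_graph=True)
# ...should match passing a GradientEdge that points at the same version of z
(g_edge,) = torch.autograd.grad(y, inputs=(get_gradient_edge(z),))
assert torch.equal(g_tensor, g_edge)  # both are 2 * z = tensor(12.)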
Post accumulate grad hooks - @janeyx99 available in versions >=2.1
- torch.Tensor.register_post_accumulate_grad_hook is an autograd backward hook API that allows one to register a hook that fires after the grad has been accumulated onto the tensor's .grad. See the documentation for how it interacts with other hooks.
- This hook is an extension point to do any post-processing necessary after gradient accumulation, e.g. it can be used to implement resharding/reduce_scatter in FSDP.
- One use case in particular: this hook makes it possible to implement optimizers that avoid keeping all the .grads alive at once, which in terms of memory usage is equivalent to keeping another set of parameters alive! Traditionally, optimizers need to wait for backward to complete in its entirety and for all .grads to be accumulated. Using this hook, however, optimizers can eagerly perform the .step() for each parameter as soon as its gradient is ready while the backward pass is running, avoiding keeping the gradients in memory. This is particularly useful in low batch-size/seq-len regimes where the parameters rather than the activations are the memory bottleneck. See this excellent tutorial for more details, as well as the sketch below.
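For illustration, here is a minimal sketch of the optimizer-in-backward pattern described above, along the lines of the tutorial (the per-parameter SGD optimizers and the optimizer_dict/optimizer_hook names are our own, not a library API):

import torch

model = torch.nn.Linear(10, 10)

# one tiny optimizer per parameter so each can be stepped as soon as its grad is ready
optimizer_dict = {p: torch.optim.SGD([p], lr=1e-2) for p in model.parameters()}

def optimizer_hook(param):
    # fires right after param.grad has been accumulated during backward
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()  # frees the gradient immediately (set_to_none=True by default)

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# parameters are updated while backward is still running,
# so the full set of .grads is never alive at once
model(torch.rand(2, 10)).sum().backward()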
foreach forward and backward AD support - @crcrpar available in versions >= 2.1
- We've made foreach operations differentiable for both forward and backward AD (a forward-mode sketch follows the backward example below).
- Some background on foreach: if you are doing the same operation on many tensors simultaneously, foreach operations allow one to potentially fuse those operations into a single kernel to speed up execution.
- Since the optimizer step is precisely this type of operation, optimizers use foreach under the hood, so being able to differentiate through foreach operations allows one to differentiate through optimizers, enabling more efficient higher-order optimization.
- This change also allows one to use foreach during the forward pass in training. You might want to use foreach in training as opposed to something like a NestedTensor if you require more flexibility with the dimensions of your constituent tensors.
- Please note that foreach operations are still a private API. Discussions are underway about making them public, however.
Example:
a = torch.rand(10, requires_grad=True)
b = torch.rand(10, requires_grad=True)
c = torch.rand(10, requires_grad=True)
d = [a, b, c]
# multiply every tensor in the list by 10 with a single fused foreach op
out = torch._foreach_mul(d, 10)
sum(out).sum().backward()
Output (with TORCH_LOGS="+autograd"):
[2024-01-05 16:41:08,273] torch.autograd.graph: [DEBUG] Executing: <SumBackward0 object at 0x100dd7f10> with grad_outputs: [f32[]]
[2024-01-05 16:41:08,274] torch.autograd.graph: [DEBUG] Executing: <AddBackward0 object at 0x100dd7f10> with grad_outputs: [f32[10]]
[2024-01-05 16:41:08,274] torch.autograd.graph: [DEBUG] Executing: <AddBackward0 object at 0x100dd7f10> with grad_outputs: [f32[10]]
[2024-01-05 16:41:08,274] torch.autograd.graph: [DEBUG] Executing: <AddBackward0 object at 0x100dd7f10> with grad_outputs: [f32[10]]
[2024-01-05 16:41:08,274] torch.autograd.graph: [DEBUG] Executing: <CppFunction object at 0x100dd7f10> with grad_outputs: [f32[10],f32[10],f32[10]]
[2024-01-05 16:41:08,274] torch.autograd.graph: [DEBUG] Executing: <AccumulateGrad object at 0x100dd7f10> with grad_outputs: [f32[10]]
[2024-01-05 16:41:08,274] torch.autograd.graph: [DEBUG] Executing: <AccumulateGrad object at 0x100dd7f10> with grad_outputs: [f32[10]]
[2024-01-05 16:41:08,274] torch.autograd.graph: [DEBUG] Executing: <AccumulateGrad object at 0x100dd7f10> with grad_outputs: [f32[10]]
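Forward-mode AD works through foreach ops as well. Here is a minimal sketch (the tensor shapes and tangent values are our own; we assume _foreach_mul propagates dual tensors as described above):

import torch
import torch.autograd.forward_ad as fwAD

primals = [torch.rand(10) for _ in range(3)]
tangents = [torch.ones(10) for _ in range(3)]

with fwAD.dual_level():
    duals = [fwAD.make_dual(p, t) for p, t in zip(primals, tangents)]
    # the fused foreach op propagates the tangents (Jacobian-vector product)
    out = torch._foreach_mul(duals, 10)
    jvps = [fwAD.unpack_dual(o).tangent for o in out]  # each should be 10 * ones(10)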
Logging for backward execution - available in versions >=2.3
- You can now enable logging for the execution of the backward pass by setting the TORCH_LOGS="+autograd" environment variable, e.g. TORCH_LOGS="+autograd" python test.py. You can also use torch._logging.set_logs(autograd=logging.DEBUG) if you wish to toggle this within code (see the short sketch below this list).
- This can be useful for debugging cases where engine execution fails outside of a backward kernel, e.g. when the last frame in the C++ stack trace shown with TORCH_SHOW_CPP_STACKTRACES=1 is not particularly helpful.
- Note that a PR updating the format of the logging is still in the process of landing, so what you observe may differ from what is shown here: Update torch.autograd.graph logging to not print out grad_output by soulitzer · Pull Request #116523 · pytorch/pytorch · GitHub
- Future work: as of today, we only log the currently executing node and information about its grad_outputs. We would appreciate any feedback on what else should be logged.
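As a small sketch of the in-code toggle mentioned above (the tiny example itself is our own):

import logging
import torch

# toggle autograd debug logging from within code instead of via the TORCH_LOGS env var
torch._logging.set_logs(autograd=logging.DEBUG)

x = torch.rand(3, requires_grad=True)
x.exp().sum().backward()  # each executed backward node is now logged at DEBUG level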
Example:
a = torch.rand(10, requires_grad=True)
b = a.mul(2).div(3).sum()
c = b.clone()
# run backward from both outputs; each executed node shows up in the log
torch.autograd.backward((b, c))
Output (with TORCH_LOGS="+autograd"):
[2024-01-05 14:09:19,866] torch.autograd.graph: [DEBUG] Executing: <CloneBackward0 object at 0x12367d7c0> with grad_outputs: [f32[]]
[2024-01-05 14:09:19,866] torch.autograd.graph: [DEBUG] Executing: <SumBackward0 object at 0x12367d7c0> with grad_outputs: [f32[]]
[2024-01-05 14:09:19,866] torch.autograd.graph: [DEBUG] Executing: <DivBackward0 object at 0x12367d7c0> with grad_outputs: [f32[10]]
[2024-01-05 14:09:19,867] torch.autograd.graph: [DEBUG] Executing: <MulBackward0 object at 0x12367d7c0> with grad_outputs: [f32[10]]
[2024-01-05 14:09:19,867] torch.autograd.graph: [DEBUG] Executing: <AccumulateGrad object at 0x12367d7c0> with grad_outputs: [f32[10]]