Background
Forward Over Reverse: Definition
Forward over reverse (meaning jvp(vjp(f))) is the preferred way to compute Hessians (second derivatives). Why it is preferred is out of scope for this document. Suffice it to say that (1) it is the faster option for users, and (2) it requires a rule for computing the jvp (forward-mode derivative) of the backward rule for a primal function.
Specifically, if I have a function foo and its associated backward, foo_backward, then to compute the forward-over-reverse Hessian I need a foo_backward_jvp, which computes the jvp of foo_backward. I could, of course, use double backward instead, meaning I would need foo_double_backward, but that tends to be slower.
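To make that concrete, here is a minimal sketch (mine, not taken from any existing rule) using the torch.func API: jvp(grad(f), ...) gives a Hessian-vector product via forward over reverse, and jacfwd(jacrev(f)) builds the full Hessian. The toy function f and variable names are just for illustration.

import torch
from torch.func import grad, jvp, jacfwd, jacrev

def f(x):
    # toy scalar function: the Hessian of sum(x**3) is diag(6 * x)
    return (x ** 3).sum()

x = torch.randn(5)
v = torch.randn(5)

# Forward over reverse: a jvp of the reverse-mode gradient is a Hessian-vector product.
_, hvp = jvp(grad(f), (x,), (v,))

# Full Hessian, two ways: forward over reverse vs. reverse over reverse ("double backward").
h_fwd_rev = jacfwd(jacrev(f))(x)
h_rev_rev = jacrev(jacrev(f))(x)
torch.testing.assert_close(h_fwd_rev, h_rev_rev)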
History
Forward over reverse functions were highly requested by functorch users. We used AOTAutograd/PrimTorch decompositions (only when running a backward function with forward mode AD) to implement these faster than we could write the rules themselves.
In doing so, we found some issues with the values produced by the forward over reverse decompositions for the norms (when compared to the double backward, or jacrev(jacrev(f)), versions). This document goes through how we fixed those issues and why I find the fix hacky.
Then we had a discrepancy between functorch and PyTorch: functorch could support more forward over reverse functions than PyTorch could. So Jeffrey moved the mechanism into core, and we ended up with functions in core tagged with @register_decomposition_for_jvp that were a holdover from functorch.
What are they? What’s the difference from the regular decompositions?
These decompositions (1) only exist for layer norm and batch norm (group norm has an explicitly implemented forward over reverse formula) and (2) differ from the regular decompositions by only a small amount. When autograd looks up the available decompositions, it checks the jvp decomposition table first and then falls back to the standard decomposition table. So this does not override the normal decomposition table and is only used by autograd.
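As a rough, purely illustrative sketch of that lookup order (the table and function names below are hypothetical, not the actual autograd internals):

# Illustrative only: these names are made up, not PyTorch internals.
jvp_decomposition_table = {}       # conceptually, entries added via @register_decomposition_for_jvp
standard_decomposition_table = {}  # the regular decomposition table

def find_decomposition_for_forward_over_reverse(op):
    # The jvp-specific table is consulted first, so its entries only shadow the
    # standard ones when autograd runs forward mode AD over a backward function.
    if op in jvp_decomposition_table:
        return jvp_decomposition_table[op]
    return standard_decomposition_table.get(op)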
Both norm backward functions take in the mean and variance statistics (as mean and rstd) computed during the forward pass. The only change from the decomposition used by other systems is that we call recompute_mean_var and use the mean and rstd it returns in the rest of the computation. The code for recompute_mean_var is:
def recompute_mean_var(input, rstd, inner_dim_indices, keepdim):
    # Recompute the statistics from the (dual) input so that they participate
    # in forward mode AD, unlike the mean/rstd saved by the forward pass.
    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
    # Recover eps from the saved rstd: since rstd = 1 / sqrt(var + eps),
    # eps = (1 / rstd)^2 - var. It is detached because eps is a constant.
    eps = torch.pow(1 / rstd, 2) - var
    eps = eps.detach()
    rstd = 1 / torch.sqrt(var + eps)
    return mean, rstd
So this just recomputes the mean and variance; in fact, the mean and variance should be exactly the same (numerically) as the mean and variance passed to the backward function. The fact that we need to run this recompute at all (including recomputing epsilon, since it is not passed to the backward) is definitely hacky and gross.
The only difference is that the original mean and var passed to the backward function are not part of any autograd graph, neither the forward one nor the backward one. The recomputed versions have a dependency on input, which at this point should be a dual number (participating in forward mode AD).
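A small, self-contained way to see this difference (my example, using torch.autograd.forward_ad directly rather than the decomposition machinery): a statistic recomputed from a dual-number input carries a tangent, while a tensor that was merely saved from the forward pass does not.

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(2, 5)
t = torch.randn(2, 5)        # tangent for x
saved_mean = x.mean(dim=-1)  # stands in for the mean saved during the forward pass

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, t)
    recomputed_mean = dual_x.mean(dim=-1)
    # The recomputed mean depends on the dual input, so it carries a tangent...
    print(fwAD.unpack_dual(recomputed_mean).tangent is not None)  # True
    # ...while the saved mean is a constant as far as forward mode AD is concerned.
    print(fwAD.unpack_dual(saved_mean).tangent is not None)       # False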
FAQ
Why doesn’t double backwards have this same problem?
We've implemented custom functions for layer norm's and batch norm's double backward. The implementers of these functions know that there should be a dependency between the input and the mean and variance, so they write the formulas to respect that dependency, and we don't rely on the autograd engine to propagate it. With the decompositions, we need the autograd engine to be the one keeping track of this dependency.
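For a sanity check in the same spirit (my sketch, using torch.func and F.layer_norm; agreement is up to the default assert_close tolerances), the two paths should produce the same Hessian for a small layer norm:

import torch
import torch.nn.functional as F
from torch.func import jacrev, jacfwd

def f(x):
    # scalar-valued wrapper around layer norm so the Hessian is a single object
    return F.layer_norm(x, normalized_shape=(5,)).sum()

x = torch.randn(3, 5)
h_double_backward = jacrev(jacrev(f))(x)  # reverse over reverse
h_fwd_over_rev = jacfwd(jacrev(f))(x)     # forward over reverse (hits the jvp decompositions)
torch.testing.assert_close(h_double_backward, h_fwd_over_rev)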
Would it work to decompose the norm functions themselves (instead of the backwards functions)?
Yes. However, we would then need to decompose the norm functions unconditionally: when we're building the autograd graph, we currently have no way of knowing whether forward mode AD will later be run on it.
Right now, we only decompose the backward function if forward mode AD is actually going to be run on it. Since the norms are composite explicit autograd functions, any kernels that backends have written for batch norm or layer norm (or their backwards) will still be used otherwise. This means we keep full performance in eager cases where we only need the forward computation or a single layer of autograd. Decomposing the forward may, however, be a reasonable option when lowering to compilers.
Why do we detach mean and var from the autograd graph? Shouldn’t they have these dependencies already?
Honestly, I'm not certain of the history behind this, but I have a guess. Mean and var are saved for the backward computation but are never seen by the user. In other words, the composite explicit versions (native_batch_norm, native_layer_norm) that save the computed mean and variance are private functions in native_functions. Their public counterparts (batch_norm, layer_norm) are composite implicit functions that just call the composite explicit versions but don't return the mean or variance.
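To illustrate (calling the aten op directly through torch.ops here, which is just one way to poke at the private op): the composite explicit op hands back the statistics, while the public API does not.

import torch
import torch.nn.functional as F

x = torch.randn(3, 5)
# The private, composite explicit op returns (output, mean, rstd)...
out, mean, rstd = torch.ops.aten.native_layer_norm(x, [5], None, None, 1e-5)
# ...while the public functional API only gives the user the normalized output.
out_public = F.layer_norm(x, (5,))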
Because of this, a user is never going to be able to pass a gradient for mean or variance. So mean and variance are only used to compute the correct values for backward and double backward. Because those functions are written to know the relationship between input, mean, and variance, it's not necessary to track them in autograd. It would also be expensive, since it would involve building a part of the autograd graph that we would never use.