Why does JVP have its own decompositions?

Background

Forward Over Reverse: Definition

Forward over reverse (meaning jvp(vjp(f))) is the preferred way to compute Hessians (second derivatives). It is out of scope for this document to explain why this is the preferred way to take a second derivative.

It suffices to say that (1) it is preferred by users from a speed perspective, and (2) it requires a rule for computing the jvp (forward mode derivative) of the backward rule for a primal function.

Specifically, if I have a function foo and its associated backward, foo_backward, then to calculate the forward over reverse Hessian I need a foo_backward_jvp, which computes the jvp of foo_backward. I could, of course, use double backwards instead, meaning I would need foo_double_backward, but that tends to be slower.
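For concreteness, here is a minimal sketch of what forward over reverse looks like from the user's side, using functorch (today's torch.func exposes the same functions); the function f below is just an arbitrary example:

from functorch import jacfwd, jacrev
import torch

def f(x):
    return (x ** 3).sum()

x = torch.randn(5)

# Forward over reverse: forward-mode jacobian of the reverse-mode jacobian.
hessian_for = jacfwd(jacrev(f))(x)

# Reverse over reverse ("double backwards"), typically slower for Hessians.
hessian_rev = jacrev(jacrev(f))(x)

print(torch.allclose(hessian_for, hessian_rev))  # True, up to numerical error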

History

Forward over reverse functions were highly requested by functorch users. We used AOTAutograd/PrimTorch decompositions (only when running a backward function with forward mode AD) to implement these faster than we could write the rules themselves.

In doing so, we found some issues with the values produced by the forward over reverse decompositions for the norms (when compared to the double backwards, or jacrev(jacrev(f)), versions). This post will go through how we fixed those issues and why I find the fixes hacky.

Then we had a discrepancy between functorch and PyTorch: functorch was able to support more forward over reverse functions than PyTorch. So Jeffrey moved the mechanism into core, and we ended up with functions in core tagged with @register_decomposition_for_jvp as a holdover from functorch.

What are they? What’s the difference from the regular decompositions?

These decompositions (1) only exist for layer norm and batch norm (group norm has a forward over reverse formula implemented explicitly) and (2) differ only slightly from the standard decompositions. When autograd looks up the available decompositions to use, it starts with the jvp decomposition table and then falls back to the standard decomposition table. So this doesn't override the normal decomposition table and is only used by autograd.
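In pseudocode, that lookup order is roughly the following (a hypothetical sketch, not the actual autograd code; the table names are made up):

# Hypothetical sketch of the lookup order described above.
def find_decomposition_for_jvp(op, jvp_decomp_table, standard_decomp_table):
    # The jvp-specific table (populated via @register_decomposition_for_jvp)
    # shadows the regular table, but only when autograd needs to run forward
    # mode AD over a backward function.
    if op in jvp_decomp_table:
        return jvp_decomp_table[op]
    return standard_decomp_table.get(op)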

Both norm functions take in the mean and var computed during the forward pass. The only change from the decomposition used by other systems is to call recompute_mean_var and use the statistics it returns in the rest of the computation. The code for recompute_mean_var is:

def recompute_mean_var(input, rstd, inner_dim_indices, keepdim):
    # Recompute the statistics from `input` so that they participate in the
    # forward-mode AD graph (the saved mean/rstd do not).
    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)

    # eps is not passed to backward, so recover it from the saved rstd:
    # rstd = 1 / sqrt(var + eps)  =>  eps = 1 / rstd**2 - var
    eps = torch.pow(1 / rstd, 2) - var
    eps = eps.detach()

    rstd = 1 / torch.sqrt(var + eps)
    return mean, rstd

So this just recomputes the mean and variance. In fact, the mean and variance should be exactly the same (numerically) as the mean and variance passed to the backward function. The fact that we need to rerun this computation (including recomputing epsilon, since it's not passed to the backward) is definitely hacky and gross.
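To illustrate the epsilon recovery trick in isolation (a standalone sketch, not code from core; the tensors and names here are made up):

import torch

# rstd is what the norm saves for backward: 1 / sqrt(var + eps).
x = torch.randn(16)
true_eps = 1e-5
var = torch.var(x, unbiased=False)
rstd = 1 / torch.sqrt(var + true_eps)

# Invert it to recover eps, as recompute_mean_var does.
recovered_eps = torch.pow(1 / rstd, 2) - var
print(true_eps, recovered_eps.item())  # close, but only up to floating point error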

The only real difference is that the original mean and var passed to the backward function are not part of any autograd graph, neither the forward-mode nor the backward one, whereas the recomputed ones depend on input, which at this point should be a dual number (i.e., participating in forward mode AD).
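As a small illustration of why that dependency matters (a sketch with made-up tensors, using the public torch.autograd.forward_ad API):

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3)
tangent = torch.ones(3)

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, tangent)

    # A statistic recomputed from the dual input carries a tangent...
    recomputed_mean = dual_x.mean()
    print(fwAD.unpack_dual(recomputed_mean).tangent)  # a tensor

    # ...while a statistic saved earlier as a plain tensor does not, so
    # forward mode AD treats it as a constant.
    saved_mean = x.mean()
    print(fwAD.unpack_dual(saved_mean).tangent)  # None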

FAQ

Why doesn’t double backwards have this same problem?

We’ve implemented a custom function for layer norm and batch norm’s double backwards. The implementers of these functions know that there should be a dependency between the input and the mean and variance, so they write the formulas to respect this dependency, and we don’t rely on the autograd engine to propagate it. With the decompositions, we need the autograd engine to be the one to keep track of this dependency.

Would it work to decompose the norm functions themselves (instead of the backwards functions)?

Yes. However, we would then need to decompose the norm functions unconditionally: when we’re building the autograd graph, we currently don’t have a way of knowing whether forward mode AD will later be run on that graph.

Right now, we’re able to decompose the backward function only when forward mode AD is actually going to be run on it. Since the norms are composite explicit autograd functions, any kernels that backends have written for batch norm or layer norm (or their backwards) will still be used otherwise. This means we keep the performance of those kernels in eager cases where we only use the forward computation or only need one level of autograd. However, decomposing the primal functions may be a reasonable solution when lowering to compilers.

Why do we detach mean and var from the autograd graph? Shouldn’t they have these dependencies already?

Honestly, I’m not certain of the history behind this, but I have a guess. Mean and var are saved for the backward computation but are never seen by the user. In other words, the composite explicit versions (native_batch_norm, native_layer_norm) that save the computed mean and variance are private functions in native_functions. Their public counterparts (batch_norm, layer_norm) are composite implicit functions that call the composite explicit versions but don’t return the mean or variance.
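For example (a rough sketch; torch.native_batch_norm happens to be callable from Python even though it isn’t public API, and its exact signature may vary across versions):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
weight = torch.ones(8)
bias = torch.zeros(8)

# The composite explicit op returns the saved statistics...
out, saved_mean, saved_rstd = torch.native_batch_norm(
    x, weight, bias, None, None, True, 0.1, 1e-5)

# ...while the public wrapper discards them, so the user never sees
# (and can never pass gradients for) mean or rstd.
out_public = F.batch_norm(x, None, None, weight, bias, training=True)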

Because of this, a user is never going to be able to pass a grad_input for mean or variance. So mean and variance are only used to compute the correct values for backwards and double backwards. Because those functions are written with the relationship between input, mean, and variance in mind, it’s not necessary to track them in autograd. It would also be expensive, since it would involve building a part of the autograd graph that we never use.


By the way, I have a PR fixing this, but I didn’t plan on landing it yet - there’s more context there on the possible ways forward: Update batch norm to compute forward grads for saved_mean and saved_var when input requires grad by soulitzer · Pull Request #81293 · pytorch/pytorch · GitHub

TLDR:

  • The PR seems to just replace one hack with another, because now we lie about an output’s backward differentiability. I can be convinced to land the PR if we don’t consider native_batch_norm a public API, but there may be better things we can do (that require more work, though):
  • (1) Either we make codegen changes (and change how we think about forward/backward differentiability - allow an output to be forward differentiable but not backward differentiable), or (2) rewrite the double backward formula (seems annoying, but could be worth a look).

What I am missing from this post is: suppose you were designing the system in a vacuum, what should it look like? What if efficiency didn’t matter? And then how would you reintroduce the optimization under controlled conditions? Under what situations would you have to do this to other functions besides these norms?

IMO, if we have a compiler with good-enough CSE, these optimisations should just go in the original decompositions and should be decided by the compiler that figures out the rematerialisation and so on. This would also get rid of these JVP-only decompositions. There are also similar VJP-only decompositions in core, and in this ideal world those should go away too. There’s a paragraph about that here:

Definitely agree with Mario, but just for clarity: I would want everything (including the batch_norm primal function) to be written as a decomposition so that I can have guarantees that the outputs of any function are differentiable. As Mario calls out in his proposal, this is true for pre/post-conditions too, and I think that if I were designing this in a vacuum, I would want to replace all of these norms with prim decompositions that are differentiable, composite compliant, the whole shebang.

This is the part that’s much more vague and hand-wavy, but I would want to reintroduce the norm kernels under conditions where we know that we can’t hurt ourselves. Specifically, if none of the inputs require grad, I’m happy to use the original kernels.

The thing that becomes weird with the forward over reverse decompositions is that we want to reintroduce the original kernels “if we will never do forward over reverse over this call,” which we don’t know when we’re doing the original call (we could run forward mode over the held-onto autograd graph much later in the program).

Also, as Mario pointed out and as is alluded to in the FAQ, if we’re in a compiler stack, my stance is that we should just let the compiler deal with it.

This may happen with any aten function that returns intermediates and hides them from users (by only exposing wrappers that don’t return those intermediates). I think it’s less likely to happen if the function returns these intermediates, since those could be used in later computation and users would see it as a bug if they didn’t propagate gradients.

I think the norms are the main example of this, though it may come up again with some of the stuff Richard is thinking through with custom vjp, especially since functorch is always thinking about higher order gradients, and that’s really where we see these problems crop up.