PyTorch hooks Part 2: nn.Module hooks

This section describes how the forward and backward hooks on Modules work.

General idea

All the hooks on Modules rely on an indirection: the user implements the forward() function to specify what should happen when the module is evaluated, but evaluates the module through its __call__() method.
As of today, this indirection is necessary for both hooks and jit tracing to work properly.
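
As a small illustration of this indirection, the sketch below registers a forward hook on a Linear module and evaluates the module both ways. The layer sizes and the `calls` list are arbitrary choices made for the example.

```python
import torch
import torch.nn as nn

calls = []  # records every time the hook fires

def record_hook(module, args, output):
    # Forward hooks receive the module, its positional inputs and its output.
    calls.append(type(module).__name__)

layer = nn.Linear(3, 2)
handle = layer.register_forward_hook(record_hook)

x = torch.randn(4, 3)
layer(x)          # goes through __call__(), so the hook runs
layer.forward(x)  # bypasses __call__(), so the hook is skipped

print(calls)  # ['Linear']
handle.remove()
```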

During __call__(), the global Module hooks are combined on the fly with the hooks registered on the current Module, so that both global and local hooks are called. In particular, __call__() follows this logic:

  • Handle the pre-forward hooks, optionally updating the inputs
  • Capture the inputs to be used by the backward hooks (this generates a new set of inputs)
  • Call the user-defined forward (with extra tracing logic if we’re tracing)
  • Capture the outputs to be used by the backward hooks (this generates a new set of outputs)
  • Handle the forward hooks, optionally updating the outputs
  • Deprecated: handle the old style backward hooks
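
The steps above can be sketched in plain Python. This is a toy model, not PyTorch's implementation: `ToyModule`, the two global lists and the attribute names are hypothetical, running global hooks before local ones is an assumption of the sketch, and the backward-hook capture steps are only marked by comments.

```python
# Hypothetical stand-ins for the global hook registries.
GLOBAL_FORWARD_PRE_HOOKS = []
GLOBAL_FORWARD_HOOKS = []

class ToyModule:
    def __init__(self):
        self.forward_pre_hooks = []
        self.forward_hooks = []

    def forward(self, *inputs):
        raise NotImplementedError

    def __call__(self, *inputs):
        # Global and local hooks are merged on the fly.
        for hook in GLOBAL_FORWARD_PRE_HOOKS + self.forward_pre_hooks:
            result = hook(self, inputs)
            if result is not None:
                # Wrap a single value so that forward(*inputs) still works.
                inputs = result if isinstance(result, tuple) else (result,)
        # (Capture of the inputs for the backward hooks would happen here.)
        output = self.forward(*inputs)
        # (Capture of the outputs for the backward hooks would happen here.)
        for hook in GLOBAL_FORWARD_HOOKS + self.forward_hooks:
            result = hook(self, inputs, output)
            if result is not None:
                output = result
        return output

class Double(ToyModule):
    def forward(self, x):
        return 2 * x

m = Double()
GLOBAL_FORWARD_PRE_HOOKS.append(lambda mod, inputs: (inputs[0] + 1,))
m.forward_hooks.append(lambda mod, inputs, output: output * 10)

result = m(3)  # pre-hook: 3 -> 4, forward: 4 -> 8, forward hook: 8 -> 80
print(result)  # 80
```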

Forward hooks

Both the pre-forward and forward hooks work very similarly.
For each of these hooks, we simply call the given function with the current inputs (respectively outputs) and, if the value returned by the hook is not None, replace the current inputs (respectively outputs) with it.
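
A minimal example of this behavior, using the public register_forward_pre_hook() and register_forward_hook() methods. The hook names and the choice of nn.Identity are only for illustration.

```python
import torch
import torch.nn as nn

def add_one_pre_hook(module, args):
    # Returning a value replaces the inputs.
    return (args[0] + 1,)

def observe_hook(module, args, output):
    # Returning None leaves the output untouched.
    return None

def negate_hook(module, args, output):
    # Returning a value replaces the output.
    return -output

layer = nn.Identity()
layer.register_forward_pre_hook(add_one_pre_hook)
layer.register_forward_hook(observe_hook)
layer.register_forward_hook(negate_hook)

out = layer(torch.tensor([1.0, 2.0]))
print(out)  # tensor([-2., -3.])
```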

Backward hooks

The backward hooks are a bit more complex because they need to be able to capture the gradients for all the Tensor inputs and outputs and be consistently called.

In the usual case, both the inputs and the outputs contain at least one Tensor that requires gradients, and the user calls backward on a Tensor that depends on both. The idea is then as follows.
We create one instance of BackwardHook that stores the required state and is responsible for calling the user-defined hooks as soon as possible.
Then we capture both inputs and outputs using this class. During the backward pass, the capture based on the outputs can access the grad outputs and saves them, while the capture based on the inputs can access the grad inputs and calls the user hooks with all the required information.
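
From the user's side, the usual case looks like the sketch below, which uses register_full_backward_hook(). The layer sizes and the `seen` dictionary are arbitrary choices for the example.

```python
import torch
import torch.nn as nn

seen = {}

def backward_hook(module, grad_input, grad_output):
    # grad_output: gradients w.r.t. the Module outputs.
    # grad_input: gradients w.r.t. the Module inputs.
    seen["grad_input"] = grad_input
    seen["grad_output"] = grad_output

layer = nn.Linear(3, 2)
layer.register_full_backward_hook(backward_hook)

x = torch.randn(4, 3, requires_grad=True)
out = layer(x)
out.sum().backward()

print(seen["grad_output"][0].shape)  # torch.Size([4, 2])
print(seen["grad_input"][0].shape)   # torch.Size([4, 3])
```

Here the gradient passed as grad_input matches what ends up in x.grad, since x is the only input of the Module.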

The capture mechanism works as follows for a given input (which can be either the inputs or the outputs of the forward method):

  • Find all the Tensors in the given input and put them in a separate tuple.
  • Use a custom Function to do an identity on this tuple of Tensors. This gives us a single Node in the autograd graph that receives the gradients for all the Tensors in the input at once.
  • Add a hook to that newly created Node to capture the gradients.
  • Put the Tensors returned by the custom Function back into a new object that can be used as the new input.
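
The four steps above can be sketched as follows. This is a simplified illustration, not the code PyTorch uses: `IdentityOnTuple`, `capture` and `on_grads` are hypothetical names, only flat tuples are handled, and every Tensor is assumed to take part in the backward pass.

```python
import torch

class IdentityOnTuple(torch.autograd.Function):
    # Hypothetical name: an identity over several Tensors that creates one Node.
    @staticmethod
    def forward(ctx, *tensors):
        return tensors

    @staticmethod
    def backward(ctx, *grads):
        return grads

def capture(values, on_grads):
    # 1. Find all the Tensors and remember their positions.
    idx = [i for i, v in enumerate(values) if isinstance(v, torch.Tensor)]
    tensors = tuple(values[i] for i in idx)
    if not any(t.requires_grad for t in tensors):
        return values  # nothing to capture
    # 2. Run the identity Function on the tuple of Tensors.
    new_tensors = IdentityOnTuple.apply(*tensors)
    # 3. Hook the single Node shared by all the new Tensors.
    node = new_tensors[0].grad_fn
    node.register_hook(lambda grad_inputs, grad_outputs: on_grads(grad_outputs))
    # 4. Put the new Tensors back at their original positions.
    new_values = list(values)
    for i, t in zip(idx, new_tensors):
        new_values[i] = t
    return tuple(new_values)

captured = []
a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([3.0, 4.0], requires_grad=True)
new_a, scale, new_b = capture((a, 5, b), captured.append)
(new_a * scale + new_b * new_b).sum().backward()

# One call to the hook delivers the gradients for both Tensors at once.
print(captured[0][0])  # tensor([5., 5.])
print(captured[0][1])  # tensor([6., 8.])
```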

There are a couple of special cases on top of this basic one, in particular:

  • If only one of the inputs or the outputs of the Module contains a Tensor that requires gradients, we still want the hooks to be called (for example, the very first layer of your neural net, where the inputs don’t require gradients but the outputs do). To handle this, we detect when this happens and call the user hooks directly from the only hook that exists.
  • If the grad inputs hook is called before the grad outputs one (for example because the user extracts an intermediary result without making it an output and then calls backward on it, to compute special terms in the loss), then the hook for the grad inputs needs to handle this case by filling the grad outputs with None to signify that no gradients were computed.
  • (TODO: this raises an error today) If the Module is evaluated in no_grad mode, then the custom Function we rely on will not be created here. The code needs to be able to silently be a no-op as no backward hook will ever be needed as we’re in no_grad mode.
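
The first special case can be observed with a short sketch: a layer fed with plain data that does not require gradients. The layer sizes and the `seen` list are arbitrary choices for the example.

```python
import torch
import torch.nn as nn

seen = []

def backward_hook(module, grad_input, grad_output):
    seen.append((grad_input, grad_output))

first_layer = nn.Linear(3, 2)
first_layer.register_full_backward_hook(backward_hook)

x = torch.randn(4, 3)   # plain data: does not require gradients
out = first_layer(x)    # requires gradients through the layer's parameters
out.sum().backward()

grad_input, grad_output = seen[0]
print(grad_input)            # (None,)
print(grad_output[0].shape)  # torch.Size([4, 2])
```

The hook is still called once, with None standing in for the gradient of the input that does not require one.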

Legacy backward hooks

These hooks are deprecated because, in many cases, they don’t return the advertised quantity.

Indeed, they work by considering the outputs after the forward hooks have been called.
The implementation searches these outputs for the first Tensor with an associated grad_fn (opening one level of dictionaries for some reason).
It then adds hooks to this grad_fn to capture its grad outputs and grad inputs, and uses these quantities to call the user hooks.

This means that the user hooks receive, as grad outputs, the gradient for the output of that Node (which might not correspond to all the outputs of the Module; non-Tensor outputs are ignored in an unpredictable way). As grad inputs, they receive the gradient for the input of that Node, which might not correspond to the inputs of the Module at all (or only to a subset of them).

This construct works for Modules whose forward is composed of a single Node (like Conv2d) but fails for all others (like Linear).
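
To make the difference concrete, the sketch below compares the legacy register_backward_hook() with register_full_backward_hook() on a small Module whose forward creates several Nodes. The `Scaled` Module and the `run` helper are hypothetical, and the sketch assumes a PyTorch version in which the legacy method is still available (it emits a warning, which is silenced here).

```python
import warnings
import torch
import torch.nn as nn

class Scaled(nn.Module):
    # Two multiplications followed by an addition: several autograd Nodes.
    def forward(self, x, y):
        return 2 * x + 3 * y

def run(register_name):
    seen = []
    module = Scaled()
    register = getattr(module, register_name)
    register(lambda mod, grad_input, grad_output: seen.append(grad_input))
    x = torch.ones(2, requires_grad=True)
    y = torch.ones(2, requires_grad=True)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # the legacy hook warns about this setup
        module(x, y).sum().backward()
    return [g.tolist() for g in seen[0]]

legacy = run("register_backward_hook")
full = run("register_full_backward_hook")
print("legacy grad_input:", legacy)
print("full grad_input:  ", full)
```

The full hook reports the gradients for x and y (2 and 3 per element), while the legacy hook should report the gradients flowing into the final addition, which are not the gradients for the inputs of the Module.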