PyTorch hooks Part 1: All the available hooks

The goal of these notes is to dive into the different sets of hooks that we have in PyTorch and how they’re implemented (with a specific focus on autograd and torch.nn hooks).

This first part is an exhaustive (to the best of my knowledge) list of the hooks that you can find in PyTorch.

The next part will dive into more detail for each of these and explain how they’re implemented.

From the Python API

Note that the type hints below are not the actual ones; they are just here to give a simple idea of what each hook function expects and returns.

Autograd hooks

  • Tensor gradient hooks via Tensor.register_hook(fn: Callable[Tensor, Optional[Tensor]])
    • The given function is called every time a gradient for this Tensor is computed.
    • These hooks can optionally return a new value for the gradient that will be used by autograd instead of the current value.
  • Autograd Node gradient hooks via Node.register_hook(fn: Callable[Tuple[Tensor, ...], Tuple[Tensor, ...]])
    • The given function is called every time this Node is executed and is given the gradients wrt the inputs that the Node just computed.
    • These hooks must always return the new gradient values (see the sketch after this list).
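
As a quick illustration, here is a minimal sketch of both hook types. It assumes the autograd Node can be reached from Python via `y.grad_fn`; exact hook signatures have shifted a bit across PyTorch versions, so treat it as illustrative rather than authoritative.

```python
import torch

x = torch.randn(3, requires_grad=True)

# Tensor hook: called with the gradient computed for x; returning a Tensor
# replaces that gradient, returning None keeps it unchanged.
handle = x.register_hook(lambda grad: grad * 2)

y = (x * 3).sum()

# Node hook: registered on the autograd Node reachable via y.grad_fn.
# It is called with the gradients flowing through this Node and returns
# the (possibly modified) gradients wrt the Node's inputs.
y.grad_fn.register_hook(lambda grad_inputs, grad_outputs: grad_inputs)

y.backward()
print(x.grad)    # tensor([6., 6., 6.]) -- the Tensor hook doubled the gradient of 3
handle.remove()  # hooks can be removed through the returned handle
```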

Torch.nn hooks

Hooks for a given Module (a short example follows this list):

  • Module pre-forward hook via Module.register_forward_pre_hook(fn: Callable[Tuple[Module, Any, ...], Optional[Tuple[Any, ...]]])
    • Can be used to get the input value just before the evaluation of the Module.forward method.
  • Module forward hook via Module.register_forward_hook(fn: Callable[Tuple[Module, Any, ...], Optional[Tuple[Any, ...]]])
    • Can be used to get the input and output values just after the evaluation of the Module.forward method.
  • Module backward hook via Module.register_full_backward_hook(fn: Callable[Tuple[Module, Tuple[Any, ...], Tuple[Any, ...]], Optional[Tuple[Any, ...]]])
    • Can be used to get the value of the gradients wrt all inputs and outputs of the Module.
  • The old Module.register_backward_hook(), which is deprecated and will be removed soon ™
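
Here is a minimal sketch of the three per-Module hooks on an nn.Linear; the hooks only print shapes and return None, so they leave the values untouched.

```python
import torch
import torch.nn as nn

mod = nn.Linear(4, 2)

def forward_pre_hook(module, inputs):
    # Called just before Module.forward; `inputs` is the tuple of positional inputs.
    print("pre-forward:", [t.shape for t in inputs])

def forward_hook(module, inputs, output):
    # Called just after Module.forward, with both the inputs and the output.
    print("forward:", output.shape)

def full_backward_hook(module, grad_input, grad_output):
    # Called during backward with the gradients wrt the Module's inputs and outputs.
    print("backward:", [g.shape for g in grad_output])

mod.register_forward_pre_hook(forward_pre_hook)
mod.register_forward_hook(forward_hook)
mod.register_full_backward_hook(full_backward_hook)

out = mod(torch.randn(3, 4, requires_grad=True))
out.sum().backward()
```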

Global hooks (same as above but affect every Module that runs; see the sketch after this list)

  • Module pre-forward hook via nn.modules.module.register_module_forward_pre_hook(fn: Callable[Tuple[Module, Any, ...], Optional[Tuple[Any, ...]]])
    • Can be used to get the input value just before the evaluation of the Module.forward method for every Module that runs.
  • Module forward hook via nn.modules.module.register_module_forward_hook(fn: Callable[Tuple[Module, Any, ...], Optional[Tuple[Any, ...]]])
    • Can be used to get the input and output values just after the evaluation of the Module.forward method for every Module that runs.
  • Module backward hook via nn.modules.module.register_module_full_backward_hook(fn: Callable[Tuple[Module, Tuple[Any, ...], Tuple[Any, ...]], Optional[Tuple[Any, ...]]])
    • Can be used to get the value of the gradients wrt all inputs and outputs of the Module for every Module that runs.
  • Again, there is the old register_module_backward_hook(), which does not work correctly and is deprecated.
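
For example, a single global forward hook can log the output shape of every Module that runs, without touching the models themselves (a minimal sketch; the returned handle works like the per-Module ones):

```python
import torch
import torch.nn as nn
from torch.nn.modules.module import register_module_forward_hook

def log_output(module, inputs, output):
    # Called after the forward of every Module in the process.
    if isinstance(output, torch.Tensor):
        print(f"{module.__class__.__name__}: {tuple(output.shape)}")

handle = register_module_forward_hook(log_output)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model(torch.randn(1, 4))

handle.remove()  # stop logging globally
```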

Torch.package

My understanding of these hooks is that they can be used to observe, and potentially react to, how the packager captures each dependency (interned in the package, kept as an extern dependency, or mocked). A minimal sketch follows the list below.

  • register_extern_hook(hook: Callable[Tuple[PackageExporter, str], None]) The hook will be called each time a module matches against an extern() pattern.
  • register_intern_hook(hook: Callable[Tuple[PackageExporter, str], None]) The hook will be called each time a module matches against an intern() pattern.
  • register_mock_hook(hook: Callable[Tuple[PackageExporter, str], None]) The hook will be called each time a module matches against a mock() pattern.
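
A minimal sketch of an extern hook that simply logs which modules end up externed while packaging a plain tensor (it assumes torch itself is externed so the pickled tensor's dependencies resolve):

```python
import torch
from torch.package import PackageExporter

def log_extern(exporter, module_name):
    # Called once for every module that matches an extern() pattern.
    print("externed dependency:", module_name)

with PackageExporter("example_package.pt") as exporter:
    exporter.register_extern_hook(log_extern)
    exporter.extern(["torch", "torch.**"])  # keep torch as an external dependency
    exporter.save_pickle("data", "tensor.pkl", torch.randn(2))
```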

Distributed hooks

  • DistributedDataParallel.register_comm_hook(state: object, hook: Callable[Tuple[object, GradBucket], Future]) allows the user to alter how gradients are communicated and reduced across processes in DDP, to allow experimenting with more complex communication algorithms (a minimal sketch is shown below).
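
For example, the plain allreduce behavior can be re-implemented as a comm hook roughly like this (a sketch modeled on the examples in the DDP documentation; it assumes the process group is already initialized and `ddp_model` is an existing DistributedDataParallel instance):

```python
import torch.distributed as dist

def allreduce_hook(state, bucket):
    # bucket.buffer() is the flattened tensor holding this bucket's gradients.
    tensor = bucket.buffer().div_(dist.get_world_size())
    fut = dist.all_reduce(tensor, async_op=True).get_future()
    # The returned Future must resolve to the reduced gradient tensor.
    return fut.then(lambda f: f.value()[0])

# ddp_model.register_comm_hook(state=None, hook=allreduce_hook)
```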

From the C++ API

Autograd hooks

  • Hook on a Tensor that is called whenever its gradient is computed: Tensor::register_hook(std::function<Tensor(const Tensor&)> fn)
  • Hook on a Node that is called just before Node::apply runs: void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook)
    • Note that pre_hooks cannot be removed
  • Hook on a Node that is called just after Node::apply runs: uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook)
  • Hook in the autograd Engine to run a special function once the currently running backward pass is done: void queue_callback(std::function<void()> callback) (can only be called from within the backward pass)
  • Hook on the grads captured in the engine (when we run autograd.grad()) that is called every time such a capture is populated: capture.hooks_.append(GradCaptureHook instance)

Module hooks

More that I don’t know about?
