The goal of these notes is to dive into the different sets of hooks that we have in PyTorch and how they are implemented (with a specific focus on autograd and torch.nn hooks).
This first part is an exhaustive (to the best of my knowledge) list of hooks that you can find in PyTorch.
The next part will dive into more detail for each of these and explain how they are implemented.
From the Python API
Note that the type hints below are not the actual ones; they are only here to give a simple idea of what each hook function expects and returns.
Autograd hooks
- Tensor gradient hooks via
Tensor.register_hook(fn: Callable[Tensor, Optional[Tensor]])
- The given function is called every time a gradient for this Tensor is computed.
- These hooks can optionally return a new value for the gradient that will be used by autograd in place of the current value (see the sketch after this list).
- Autograd Node gradient hooks via
Node.register_hook(fn: Callable[Tuple[Tensor, ...], Tuple[Tensor, ...]])
- The given function is called every time this Node is executed, after it has computed the gradients wrt its inputs.
- These hooks must always return the new gradient values.
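As an illustration of the Tensor-level hook, here is a minimal sketch (the tensor values and the doubling of the gradient are arbitrary choices made for illustration):

```python
import torch

t = torch.rand(3, requires_grad=True)

def double_grad(grad):
    # Called every time a gradient wrt `t` is computed.
    # Returning a Tensor makes autograd use it in place of the original gradient.
    return grad * 2

handle = t.register_hook(double_grad)

t.sum().backward()
print(t.grad)  # tensor([2., 2., 2.]) instead of all ones, because of the hook

handle.remove()  # the hook can be removed through the returned handle
```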
Torch.nn hooks
Hooks for a given Module:
- Module pre-forward hook via
Module.register_forward_pre_hook(fn: Callable[Tuple[Module, Any, ...], Optional[Tuple[Any, ...]]])
- Can be used to get the input value just before the evaluation of the Module.forward method.
- Module forward hook via
Module.register_forward_hook(fn: Callable[Tuple[Module, Any, ...], Optional[Tuple[Any, ...]]])
- Can be used to get the input and output values just after the evaluation of the Module.forward method (see the sketch after this list).
- Module backward hook via
Module.register_full_backward_hook(fn: Callable[Tuple[Module, Tuple[Any, ...], Tuple[Any, ...]], Optional[Tuple[Any, ...]]])
- Can be used to get the value of the gradients wrt all inputs and outputs of the Module.
- The old Module.register_backward_hook() is deprecated and will be removed soon ™
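As an illustration of the three per-Module hooks above, here is a minimal sketch on an nn.Linear (the layer, input shapes, and printed messages are arbitrary choices made for illustration):

```python
import torch
from torch import nn

mod = nn.Linear(4, 2)

def pre_forward_hook(module, inputs):
    # `inputs` is the tuple of positional arguments passed to module.forward.
    # Returning a value here would replace the inputs.
    print("pre-forward:", [i.shape for i in inputs])

def forward_hook(module, inputs, output):
    # Called just after module.forward; returning a value would replace the output.
    print("forward:", output.shape)

def full_backward_hook(module, grad_input, grad_output):
    # Called once the gradients wrt the module's inputs have been computed.
    # Returning a tuple here would replace grad_input.
    print("backward:", [g.shape for g in grad_output])

h1 = mod.register_forward_pre_hook(pre_forward_hook)
h2 = mod.register_forward_hook(forward_hook)
h3 = mod.register_full_backward_hook(full_backward_hook)

out = mod(torch.rand(3, 4, requires_grad=True))
out.sum().backward()

# All three registration calls return handles that can be used to remove the hooks.
for h in (h1, h2, h3):
    h.remove()
```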
Global hooks (same as above but affect every Module that runs)
- Module pre-forward hook via
nn.modules.module.register_module_forward_pre_hook(fn: Callable[Tuple[Module, Any, ...], Optional[Tuple[Any, ...]]])
- Can be used to get the input value just before the evaluation of the Module.forward method for every Module that runs.
- Module forward hook via
nn.modules.module.register_module_forward_hook(fn: Callable[Tuple[Module, Any, ...], Optional[Tuple[Any, ...]]])
- Can be used to get the input and output values just after the evaluation of the Module.forward method for every Module that runs (see the sketch after this list).
- Module backward hook via
nn.modules.module.register_module_full_backward_hook(fn: Callable[Tuple[Module, Tuple[Any, ...], Tuple[Any, ...]], Optional[Tuple[Any, ...]]])
- Can be used to get the value of the gradients wrt all inputs and outputs of the Module for every Module that runs.
- Again, there is the old register_module_backward_hook() that does not work and is deprecated.
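A minimal sketch of a global forward hook; the model and the printed message are arbitrary choices made for illustration, and the pre-forward and full-backward variants are registered the same way:

```python
import torch
from torch import nn

def global_forward_hook(module, inputs, output):
    # Runs after the forward of every Module in the process.
    print(f"{module.__class__.__name__} produced an output of shape {output.shape}")

handle = nn.modules.module.register_module_forward_hook(global_forward_hook)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model(torch.rand(3, 4))  # the hook fires for the Sequential and each submodule

handle.remove()  # global hooks are removed via their handle as well
```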
Torch.package hooks
My understanding of these hooks is that they can be used to alter the behavior of the packager with respect to how dependencies are captured: included in the package, included as an external dependency, or mocked (see the sketch after the list below).
- register_extern_hook(hook: Callable[Tuple[PackageExporter, str], None])
- The hook will be called each time a module matches against an extern() pattern.
- register_intern_hook(hook: Callable[Tuple[PackageExporter, str], None])
- The hook will be called each time a module matches against an intern() pattern.
- register_mock_hook(hook: Callable[Tuple[PackageExporter, str], None])
- The hook will be called each time a module matches against a mock() pattern.
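A minimal sketch of the extern hook; the file name, patterns, and saved module are arbitrary choices made for illustration, and the exact extern()/intern() patterns needed depend on the dependencies of what you save:

```python
import torch
from torch.package import PackageExporter

def extern_hook(exporter, module_name):
    # Called once for every module that matched an extern() pattern.
    print(f"extern'd dependency: {module_name}")

model = torch.nn.Linear(2, 2)

with PackageExporter("my_package.pt") as exporter:
    exporter.register_extern_hook(extern_hook)
    exporter.extern("torch.**")   # keep torch as an external dependency
    exporter.intern("**")         # package everything else
    exporter.save_pickle("model", "model.pkl", model)
```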
Distributed hooks
- DDP communication hook via
DistributedDataParallel.register_comm_hook(state: object, hook: Callable[Tuple[object, GradBucket], Future])
- Allows the user to alter how the gradients are reduced and communicated across processes in DDP, to allow experimenting with more complex algorithms (see the sketch below).
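A minimal sketch of a custom communication hook that averages gradients with an asynchronous all-reduce; it assumes a process group has already been initialized (e.g. via dist.init_process_group), and the model is an arbitrary choice made for illustration:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def allreduce_mean_hook(state, bucket):
    # `bucket.buffer()` is the flattened tensor holding this bucket's gradients.
    work = dist.all_reduce(bucket.buffer(), async_op=True)
    # The hook must return a Future holding the final gradient tensor for the bucket.
    return work.get_future().then(
        lambda fut: fut.value()[0] / dist.get_world_size()
    )

# Assumes dist.init_process_group(...) was called earlier in each process.
ddp_model = DDP(torch.nn.Linear(10, 10))
ddp_model.register_comm_hook(state=None, hook=allreduce_mean_hook)
```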
From the C++ API
Autograd hooks
- Hook on a Tensor that is called whenever its gradient is computed, via
Tensor::register_hook(std::function<Tensor(const Tensor&)> fn)
- Hook on a Node that is called just before Node::apply is called, via
void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook)
- Note that pre-hooks cannot be removed.
- Hook on a Node that is called just after Node::apply is called, via
uintptr_t add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook)
- Hook in the autograd engine to run a special function once the backward pass that is currently running is done, via
void queue_callback(std::function<void()> callback)
(can only be called from within the backward pass)
- Hook on the grads captured in the engine (when we run autograd.grad()) that is called every time such a capture is populated:
capture.hooks_.append(GradCaptureHook instance)
Module hooks
- None; there is a feature request: Forward/backward hooks for C++ torch::nn modules (pytorch/pytorch issue #25888 on GitHub).