How to read the autograd codebase

This document will try to give you a good idea of how to browse the autograd-related source in PyTorch. The goal is to get you familiar with what the key pieces are, where they are located, and the order in which you should read them.
Warning: this is by no means trying to show how things should ideally be done, but rather the current state. In retrospect, we could do quite a few things better.

Note that all the file locations and links are pinned at ddf2681. Things might have moved in master since this was written. Let me know if this gets outdated and I can update the links. At the time of writing, the forward mode AD code is not complete, so you can ignore all references to it for now. The profiler code will also be ignored in this discussion as it is fairly independent from the autograd and maintained by different people.

This post assumes that you are familiar with Automatic Differentiation and what it tries to do.
All the discussion here will be about backward mode AD (or back propagation). We do backward mode AD by creating a computational graph while the function is being evaluated (forward pass) and then executing this graph when the user asks for gradients (backward pass).

At a very high level, autograd needs to provide the following functionalities:

  • Build computational graph along with the forward
  • Evaluate that graph to do a backward
  • Provide high level APIs
  • Provide testing tools

The short version of the full description below is:

  • Low level data structures that are the components of the computational graph (Node, SavedVariable, Tensor)
  • Wrappers for forward functions to build the graph
    • Manually in C++
    • Via codegen and derivatives.yaml
    • Via custom Function (in Python or C++)
    • Special care for view/inplace
  • An engine to execute the computational graph
  • Hook system to control execution of the graph
  • Python bindings for all c++ objects
  • A Python and a C++ API
  • A high level API to compute complex quantities in Python
  • Gradient checking utilities for testing
  • Anomaly detection utility for debugging and testing
  • Dispatcher and how autograd integrates into PyTorch

Basic Components of the Computational Graph

The Node class
The graph that we build is a DAG based on the Node class. You can read its description here.
You can also read the rest of the Node class implementation in csrc/autograd/function.{h,cpp} with special care to:

  • operator()() that is used to evaluate this Node
  • next_edges() that is used to find the child Nodes in the graph (corresponding to the .next_functions attribute of the python binding; see the short python sketch after this list)
  • sequence_nr() that is used to provide a thread-local ordering of the Nodes. This ordering is used to ensure determinism in the execution order of the engine (when only a single device is used).
  • pyobj() that is used to store the corresponding python wrapper
  • metadata() that is used to store user metadata associated with this Node, see here for the c++ metadata and here for the python metadata.
  • release_variables() that is used to release all the resources saved by that Node
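
These attributes are easiest to see through the python bindings. A minimal sketch (the exact printed class names, like SumBackward0, may differ slightly between versions):

```python
import torch

x = torch.randn(2, requires_grad=True)
y = (x * 2).sum()

# y.grad_fn is the python binding of the Node that computes the backward
# of the last op (a sum here).
print(y.grad_fn)                  # e.g. <SumBackward0 object at 0x...>

# next_functions exposes next_edges(): the child Nodes in the graph,
# together with the input index they correspond to.
print(y.grad_fn.next_functions)   # e.g. ((<MulBackward0 object at 0x...>, 0),)
```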

Subclasses of Node
We can’t directly create instances of Node because it is abstract. All nodes in our DAG are actually instances of Node’s subclasses. These usually implement at least:

  • A custom constructor
  • An apply() function that is called when the Node is evaluated
  • release_variables() if this Node stores any state at construction

Since the backward often needs some values from the forward to perform its job, subclasses of Node also typically store the inputs/outputs from the forward computation that they need to compute the backward pass. We have a special class to save these (without creating ref cycles that would leak memory): SavedVariable.

Saved Variables
SavedVariable is implemented in csrc/autograd/saved_variable.{h,cpp}
There are only two functions there:

  • The constructor that saves the plain Tensor data, stores the fields of autograd metadata on the side (this breaks the cycle mentioned above) and saves the current “version” of the Tensor.
  • The .unpack() function that can be used to get the original Tensor + autograd metadata back. It also checks that the “version” of the Tensor didn’t change, to make sure that no in-place operation was done on it since it was saved.

What would happen if we didn’t have SavedVariables?
Note that the cycle mentioned above would appear because the Node saves the Tensor while the Tensor has a grad_fn field which links back to the Node. This class avoids this cycle by making sure that the saved Tensor is not the same as the one that has the grad_fn field.

EDIT: More work has happened in this class since then. In particular, there are now hooks that can be used to specify what should happen while packing/unpacking. See this doc for more details.
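
If your version is recent enough to have these hooks, you can observe the pack/unpack mechanism from python. A minimal sketch (assuming torch.autograd.graph.saved_tensors_hooks is available in your build):

```python
import torch

def pack(t):
    print("packing a tensor of shape", t.shape)
    return t            # could return anything here, e.g. a CPU copy

def unpack(packed):
    print("unpacking")
    return packed       # must give back a Tensor equivalent to what was packed

x = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x.sin()         # sin saves its input for the backward -> pack() runs here

y.sum().backward()      # the saved input is needed -> unpack() runs here
```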

AutogradMeta
In addition to grad_fn, the output tensor’s “entry point” into the backward graph described above, every tensor also holds other metadata that is important to autograd. This metadata is stored in the AutogradMeta struct, which is uniquely owned by each tensor. In addition to an owning reference to grad_fn_, AutogradMeta also stores a weak reference to grad_accumulator_, an owning reference to the grad_ Tensor, as well as some other fields.

As of today in PyTorch, all Tensors are autograd aware and can store such metadata as seen here. This used to be different and we had Variables that were the autograd-aware wrapper for Tensors. As you can see here, the two are exactly the same type now and you can read Variable in the autograd codebase as Tensor.
Also, we have a special kind of Node associated with every Tensor (in a lazy manner, as seen here) called AccumulateGrad. These are the leaf Nodes in the DAG and will, when executed, accumulate gradients into the .grad field of the corresponding Tensor.
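
You can see these AccumulateGrad Nodes, and the accumulation behavior, from python. A small sketch (printed names may vary by version):

```python
import torch

w = torch.randn(3, requires_grad=True)   # a leaf Tensor
out = w.sum()

# The leaf is attached to the graph via its AccumulateGrad Node:
print(out.grad_fn.next_functions)        # ((<AccumulateGrad object at 0x...>, 0),)

out.backward()
print(w.grad)                            # tensor([1., 1., 1.])

# Running backward again accumulates into .grad instead of overwriting it.
w.sum().backward()
print(w.grad)                            # tensor([2., 2., 2.])
```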

That is all!
Now it’s all about building this DAG properly and executing it.

Building the Graph

This section presents how we build the graph and all the subtleties involved in making sure it computes the right thing in all cases (or raises an error if we cannot handle a particular case).
The next section will present how to execute this graph in detail.

Manually

The most basic way to build the graph can be found here.
This function takes the inputs of the computation, the outputs of the computation, as well as a lambda that creates the Node, and builds the graph accordingly. In particular, it:

  • If no input requires gradients, marks all outputs as requires_grad=false and returns
  • Collects the next edges from the inputs to the function
  • Builds the Node based on these next edges
  • For every output Tensor, makes it point properly to the newly created Node

We have a (very) small number of Nodes that are created this way; they are implemented in the csrc/autograd/functions/ folder. All of them are here because they require very manual tuning of what happens, which does not fit with the more generic tools we provide to create the graph.
We will come back to most of these, but a simple example is the Error Node (that fails in the forward) that will just raise an error if anyone tries to use it. You can see its definition here and the apply function’s implementation here.
Another example is the DelayedError one here, which will raise an error if the backward of this Node is called.

Codegen

The problem is that writing a new class for every single operator in PyTorch would be very tedious: introducing the codegen.

Indeed, we use tools/autograd/derivatives.yaml as a configuration file that specifies the backward formulas for (almost) all our operators.

The codegen should:

  • Replace the manual code above to define the Node
  • Add code to automatically create the graph along with the forward
  • Add code to make factory functions (like torch.tensor() or torch.rand()) autograd-aware

For this part, I would recommend that first-time readers skip reading the python code altogether. Reading examples from the generated files will be enough to understand how things work. For a good example, you can look for the functions associated with abs and angle.
A reading guide for the python version would be another post altogether and is left for future work.

The main entry point for the codegen is tools/autograd/gen_autograd.py that will generate all the necessary files in torch/csrc/autograd/generated/. In particular, it will use:

  • tools/autograd/gen_autograd_functions.py to
    • generate all the Nodes that are required in torch/csrc/autograd/generated/Functions.{h,cpp}. It will make sure for each of them that:
      • The .apply() method performs the computation specified in derivatives.yaml
      • All inputs that will be used in the backward computation are saved and the .release_variables() method will release everything that was saved
      • The .name() method returns the name of this Node
    • generate python bindings for these Nodes in torch/csrc/autograd/generated/python_functions.{h,cpp} to allow them to be accessed from python easily (and have checks like isinstance work properly).
    • You should look for AbsBackward and AngleBackward for the two examples (a quick python check is shown after this list).
    • Note as well that helper functions like angle_backward() are implemented in torch/csrc/autograd/FunctionsManual.{h,cpp}.
  • tools/autograd/gen_variable_type.py to generate the code that wraps ALL the Tensor methods in PyTorch in torch/csrc/autograd/generated/VariableTypeEverything.cpp.
    • These wrappers are responsible for:
      • Checking if the output will require gradients
      • Creating an instance of the Node for that op if needed
      • Saving the relevant input/outputs
      • Building the graph accordingly
      • Setting up the view relationship
      • Performing in-place checks and version counter bumps
    • You can look for abs() and angle() in this file
    • Note that some functions have special wrappers that can be found in torch/csrc/autograd/VariableTypeManual.{h,cpp}. These are used to implement functions that are not supported by the codegen or do not compute derivatives (like detach).
  • tools/autograd/gen_trace_type.py that is very similar to the variable type above, but generates a file called TraceTypeEverything.cpp and is only responsible for performing tracing if it is enabled.
  • tools/autograd/gen_variable_factories.py that generates torch/csrc/autograd/generated/variable_factories.h. This contains a wrapper for each factory function that makes sure the requires_grad= argument is properly handled. You can check zeros() in there for an example.
  • tools/autograd/load_derivatives.py is a helper file to interpret the content of derivatives.yaml, nothing to read here.
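
A quick way to convince yourself that these generated Nodes are what ends up in the graph is to look at the grad_fn of the output of abs from python. The exact class name printed depends on the version (AbsBackward in older builds, AbsBackward0 in newer ones):

```python
import torch

x = torch.randn(4, requires_grad=True)

# This grad_fn object is the python binding of the codegen'ed Node
# from torch/csrc/autograd/generated/Functions.cpp.
y = x.abs()
print(type(y.grad_fn).__name__)   # AbsBackward (or AbsBackward0)
```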

Custom Functions

While the codegen allows us to generate all this code very easily, it is not a convenient tool for third-party users who want to define new Nodes: introducing custom Functions.

These allow the user to specify the forward and backward of a given operation and will do all the necessary work to build the graph and set up the proper Node.
A good introduction on how to use the python API for this is here: Extending PyTorch — PyTorch 1.10.0 documentation; the c++ API mimics it as much as possible.
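
Here is a minimal python custom Function (essentially the Exp example from that doc) that you can keep in mind while reading the implementation details below:

```python
import torch

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        result = x.exp()
        ctx.save_for_backward(result)   # stored via SavedVariables under the hood
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result     # d(exp(x))/dx = exp(x)

x = torch.randn(3, requires_grad=True)
y = Exp.apply(x)                        # .apply() sets up the graph described below
y.sum().backward()
print(x.grad)                           # matches x.exp()
```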

The C++ implementation is the simplest and is organized as follows:

  • You can find here the base class that users should extend and define the forward and backward for.
  • And here is the .apply() function called during the forward that is responsible for calling the user’s forward as well as setting up the graph.
  • The most important function here is the ._wrap_outputs() one here, which is used by both the python and c++ APIs and is the one setting up the graph. You should take the time to read this function in detail to understand what it is doing.
  • This Node is of a special type, defined here, that stores the necessary info as well as the AutogradContext, defined here, which allows reproducing an API similar to the python one.
  • This Node’s .apply() method here is then responsible for calling the user-defined backward as well as releasing the related resources.

The Python implementation is more complex as most of the components are still implemented in c++ using the CPython C API:

  • When reading this code, you can ignore all the code related to instantiating or calling the Function class instances. This is the old API for custom Function that will be removed soon.
  • The user facing python class can be found here. It defines the two functions that should be overwritten by the user.
  • You can check the _ContextMethodMixin here that defines all the functions available on the ctx for the python users.
  • The _HookMixin here allows the user to register hooks on these custom Functions (not sure if this one is actually documented).
  • The FunctionMeta here that handles the fallback to old-style Functions
  • The _FunctionBase that is defined in c++ here; it implements all the low-level API associated with the ctx
    • You can see accessors for all the attributes saved on the c++ Node here
    • You can see its .apply() method here that is responsible for creating the ctx, calling the user specified forward and setting up the graph properly with the same ._wrap_outputs() function as above (be careful we have two functions with that name that call each other!)
    • The Node it adds in the graph is a PyNode defined here that has a special apply function here that is responsible for calling the python class’s backward via the .apply() of the python class defined here. The PyNode apply is also responsible for packing inputs into python objects and unpacking the results back into c++ objects.

Special care for view and in-place

PyTorch’s autograd allows the user to use views and in-place operations, but special care needs to be taken to make these work.
The gist of what needs to be done here is:

  • Ensure that all differentiable views are properly tracked.
  • Ensure that in-place operations properly bump version counters.
  • Ensure that in-place operations properly do graph rewrites if needed.
  • Ensure that views get the right graph if their base or another view was modified in-place.

The view tracking is done via the as_view() function here. This function is called by all the codegen’ed and manual functions if needed. You can check the codegen for view() for example in VariableTypeEverything.cpp to see that call. You can also check the manual code for detach() in VariableTypeManual.cpp for an example of a non-differentiable view.
You can then read the make_variable_differentiable_view() in torch/csrc/autograd/variable.h to see how this is done: namely, using a special type of autograd metadata that also stores the view information.
Note that custom Functions do not do this! Indeed, the view info is tracked even in no_grad mode, so we can simply read the view info on the output of the custom Function to know if it is a view. This won’t work if the user reads the raw content of the Tensor in the forward and creates a new Tensor that aliases it, but we are happy with this limitation.

The version counter bump is done in the codegen as well via the .increment_version() function call. You can find it for example when looking for relu_() in VariableTypeEverything.cpp.
The custom functions also do this via the .mark_dirty() function on the context. It allows the user to tell the autograd what they modified in-place.
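
A small python example of what this version counter bump protects against (the exact error message may differ between versions):

```python
import torch

a = torch.randn(3, requires_grad=True)
b = a * 2
c = b.sin()      # sin saves its input b for the backward (as a SavedVariable)
b.add_(1)        # in-place op: bumps b's version counter

# The SavedVariable check notices that b changed since it was saved and raises
# something like: "one of the variables needed for gradient computation has
# been modified by an inplace operation"
c.sum().backward()
```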

When an in-place operation happens, we need to rewrite the graph to make sure that all the views of that same memory see this change in their history. This is done by the .rebase_history() function that is called both by the codegen and the custom Functions if needed. The main entry point is here and calls into the c++ implementation here.
As you can see there, this rewrite is only needed if the Tensor being modified in-place is a view. In that case, the following happens:

  • Get the base of that view.
  • Generate a new CopySlices Node that wraps the Node corresponding to the in-place op
    • This special Node expects a gradient of the same size as the base
    • Slice that gradient to be the size of the view being worked on
    • Call the original Node corresponding to the in-place op
    • Embed this small gradient in the full gradient buffer that is the same size as the base
  • Set this CopySlices as the new grad_fn for the base → meaning that this grad_fn will now be used by all the views!
  • Trigger an update of the grad_fn for this view implemented here
    • If this Tensor is a view and has been modified in-place since the last time we generated its grad_fn (checked via the “version”):
    • Regenerate the graph between the base of this Tensor and this Tensor
    • Set the grad_fn to this newly generated graph (which now points properly to the CopySlices generated above)

Ensuring that all other views properly point to the right graph is done automatically via the special grad_fn call described above. Indeed, other views will have their “version” updated as well by the version counter bump and so it will trigger a recompute of their graph the first time someone tries to access it.
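
You can see this rewrite from python by inspecting grad_fn before and after an in-place op on a view. A small sketch (the printed Node names are version-dependent):

```python
import torch

base = torch.randn(4, requires_grad=True).clone()  # clone so the base is not a leaf
view = base[:2]                                    # differentiable view of base

print(base.grad_fn)   # e.g. <CloneBackward0 object at 0x...>
view.mul_(2)          # in-place op on the view

# The base's grad_fn has been rewritten to the CopySlices Node described above,
print(base.grad_fn)   # e.g. <CopySlices object at 0x...>
# and the view's grad_fn is lazily regenerated from the base the first time it
# is accessed (it now goes through the CopySlices via a strided/slice Node).
print(view.grad_fn)   # e.g. <AsStridedBackward0 object at 0x...>
```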

Executing the Graph: The Engine

The execution engine is quite complex but fairly well contained. It is completely implemented in torch/csrc/autograd/engine.{h,cpp} and torch/csrc/autograd/python_engine.{h,cpp}.

We will get back to the python version of the engine at the end of this section as it only makes very minimal changes.

The Engine is a stateful class of which a single instance exists in the whole PyTorch process. It is defined here and is created along with the process.
The high level ideas for the engine are:

  • Given outputs and inputs of the computations, it should run all the graph to compute gradients for the inputs.
  • It runs the backward on the same device as the forward
  • It runs CPU backward on the thread that called the engine. All the GPU backward run on different worker threads.
  • The engine is re-entrant as it may be called from within a backward Node. It must have measures to avoid stack overflows when there are deep recursive calls like this.
  • The engine provides a callback system that allows the user to run a function once the execution of a given task (to be defined below) is finished.

The engine has a single public API to execute the graph: .execute().
The only other user-facing API is the one to register callbacks. We will ignore this one for now.

You can read the code in the following order:

  • Start reading the .execute() code here.
    • An important data structure there is the GraphTask defined here that will be used to store everything we need for this particular invocation of the execute function. In particular, it stores things like:
      • What remains to be done
      • If an error happened and if it is done
      • execute’s arguments such as keep_graph and grad_mode
      • Buffers for the gradients currently flowing back
      • Execution info about what needs to be executed for this particular call to execute. Remember that the same graph can be used (concurrently!) by multiple execute calls, so the execute call must never modify the graph.
      • The calling thread's TLS, to be able to propagate it to worker threads
      • CUDA stream info
      • re-entrant info
      • cpu work queue
      • callbacks
    • You can read all the functions called there, in particular compute_dependencies that is responsible for checking what needs to run in the graph
    • Ignore execute_with_graph_task for now
    • Note that at the end of this function, accessing the future’s value will throw if a worker thread sets an exception on the future. This is how the exception stack traces are propagated from the worker thread back to the main thread. Also this wait is always a no-op as, by this time, the execution is finished.
  • Go into execute_with_graph_task now that can be found here
    • It will set the final part of the graph task
    • Then push onto the right ready queue the first task that needs to be done
    • And it will go back to work
      • If it is a thread that was not involved in autograd before, start processing the CPU ready queue via thread_main. This function will exit once nothing remains to be done for this graph task
      • If this was already a worker thread (meaning we’re doing re-entrant autograd)
        • If you’re not too deep, update the depth and go into thread_main
        • If you’re too deep and risk a stack overflow, create a new thread via add_thread_pool_task that will do the work for you.
  • You can now check thread_main here that contains the main evaluation loop. You can follow all the calls into evaluate_function to see how we prepare the inputs, call the Node hooks, detect/set errors, validate the shapes, populate the input buffers, etc
  • Once all the tasks are done and the future is marked as such, the post-processing function here is called as well to run the callbacks and do the cuda stream post-processing.
  • It is important to see as well that the start_device_thread function is called with an std::call_once and is implemented here. It is used to initialize the worker threads (lazily, the first time a user calls into the engine) for the engine singleton. You can see that each worker thread runs thread_init, implemented here, and sets its device properly before calling thread_main with an empty graph_task, meaning that it will run until it receives a special kill task.
  • Similarly, the re-entrant threadpool has long running threads that will execute one graph_task at a time as shown here.

The python engine has only very limited differences with the main engine:

  • It has a special pair of thread_on_exception here and execute here to save the full python stack trace on the worker thread and restore it properly on the main thread.
  • Special logic in the worker thread to pre-initialize the python state here to make sure that GIL acquisition in the worker threads is fast.

Hook system

As you’ve seen in the engine code, it properly calls both the pre- and post-hooks that are registered on the Node.

You can see these being attributes of all the Nodes in torch/csrc/autograd/function.h and you can find their definition in torch/csrc/autograd/function_hook.{h,cpp}.

You can also find specialization of these hooks to be able to register arbitrary std::function in torch/csrc/autograd/cpp_hook.{h,cpp}.
And you can find specialization to be able to register python functions in torch/csrc/autograd/python_hook.{h,cpp}.

All of these should be self-contained and easy to read.
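
From python, the most common entry point into this hook system is Tensor.register_hook(), which ends up attaching one of these hooks to the corresponding Node (or to the AccumulateGrad Node for leaf Tensors). A small sketch:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x * 2

# This hook is called by the engine with the gradient flowing back for y,
# and whatever it returns replaces that gradient.
handle = y.register_hook(lambda grad: grad * 10)   # rescale the gradient

y.sum().backward()
print(x.grad)        # tensor([20., 20., 20.]) instead of tensor([2., 2., 2.])

handle.remove()      # hooks can be removed via the returned handle
```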

Python bindings for all c++ objects

All the files in torch/csrc/autograd/ that start with python_ are actually implementing python-only logic.
We’ve seen a couple above related to the engine and the hooks; the other ones mainly define the python bindings for all the objects we have in c++.

Variables (Tensors)
A representative example here is the manual binding we do for Variables, which are actually Tensors.
You can see the custom python object defined here.
The corresponding python type is defined here. The functions provided in this struct are important as they should follow the very strict rules for the CPython type API. In particular, THPVariable_pynew and THPVariable_dealloc are custom creators and destructors. And THPVariable_traverse and THPVariable_clear are used to ensure ref cycles can be cleared by python’s GC. You can check these implementations if you’re curious.
All the special attributes and methods for it are defined here. You can check above in this file for the implementation of all of these.
You can then see here that this is the type that is exported as _TensorBase to the python side and that is used by torch as the parent class for all the Tensor types.

Nodes
You can also find the python binding for the Node created for python custom Functions here.
The binding for all the other Nodes that only live in c++ is here. This one is a bit special because we create a different python class for every Node we define in c++.
The binding for the execution engine is here and is very simple as it only exposes two functions (the one to run backward and the one to add callbacks).


A Python and a C++ API

The end-user API in both cases is composed of two functions (a short python example follows the list):

  • backward() that runs the whole graph and accumulates gradients in all the leaves’ .grad fields
  • grad() that returns the gradients for the specified inputs
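
Here is what these two look like from python (a minimal sketch):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()

# backward(): runs the graph and accumulates into the leaves' .grad fields.
torch.autograd.backward(y)        # equivalent to y.backward()
print(x.grad)                     # 2 * x

# grad(): returns the gradients for the specified inputs instead.
z = (x ** 2).sum()
(gx,) = torch.autograd.grad(z, (x,))
print(gx)                         # 2 * x; x.grad is left untouched this time
```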

The end-user c++ API is implemented in torch/csrc/autograd/autograd.{h,cpp}.
It is a very thin layer above the engine's execute function. You can read these two files in detail.

Note as well that the Tensor.backward() function in c++ is implemented in VariableTypeManual.cpp and is calling into this user-facing API directly.

The end-user python API is implemented in torch/autograd/__init__.py and should be fairly straightforward as well, as it only does error checking and then calls into the c++ implementation.
It calls into the run_backward() function here that is unpacking the python objects, releasing the GIL and then calling into the engine.

A high level API to compute complex quantities in Python

The functional API is a self-contained API that builds on top of the basic API mentioned in the previous section.
It is implemented in torch/autograd/functional.py.

You should read the functions there in the following order (a small usage example follows the list):

  • vjp
  • jacobian
  • vhp
  • hessian
  • jvp
  • hvp
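
A small usage example for the first two; the others follow the same pattern:

```python
import torch
from torch.autograd.functional import vjp, jacobian

def f(x):
    return x ** 2      # elementwise square, Jacobian is diag(2 * x)

x = torch.randn(3)
v = torch.ones(3)

# vjp: returns f(x) and the vector-Jacobian product v^T J
out, vjp_val = vjp(f, x, v)
print(vjp_val)          # 2 * x

# jacobian: the full Jacobian of f at x (a 3x3 diagonal matrix here)
print(jacobian(f, x))
```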

Gradient checking utilities for testing

In python, we provide some gradient checking utilities to make sure that the gradients computed by autograd are correct.
There are two functions for this, gradcheck() and gradgradcheck(), that can be found in torch/autograd/gradcheck.py.

You should read gradcheck first and follow each function call in that file.
This function does:

  • Validate the inputs
  • Make sure that if no output requires gradients, the numerical gradients are all 0
  • Make sure that the analytical gradients (computed via autograd) match the numerical gradients, with special handling for the complex case
  • Make sure the backward pass is not ignoring the grad_output that is given to it, by passing a 0 grad_output and making sure that all gradients are 0
  • Make sure that the gradient formula handles undefined Tensors properly (we use undefined Tensors to represent 0 gradients)

Then gradgradcheck does the same thing but for second-order gradients; a small usage example follows.
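
A minimal sketch (gradcheck wants double-precision inputs so that the numerical gradients are accurate enough):

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

# Inputs must require grad and should be double precision.
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)

# Checks first-order gradients of torch.tanh against numerical ones.
print(gradcheck(torch.tanh, (x,)))        # True (raises on mismatch)

# Same idea for second-order gradients.
print(gradgradcheck(torch.tanh, (x,)))    # True
```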

Anomaly detection utility for testing

Anomaly detection’s goal is to help users debug errors that happen in the backward pass.
In particular, when bad things happen in the backward, since it runs (almost) completely in c++, it can be hard to know what is happening from the python side.
To this end, the anomaly detection allows two things:

  • Detect nans in the gradients as soon as they appear, so that you know directly which Node created the nan values first.
  • When an error happens while executing a given Node in the backward, also print the stack trace of the forward when this Node was created.

The anomaly mode is enabled from python via a context manager defined in torch/autograd/anomaly_mode.py.
It is enabled from c++ via a similar RAII guard defined here.
Both of these set a global flag that is defined here and here.

The nan detection is implemented in the engine here.

For the enhanced stack traces, it is a bit more complex (a small example follows the list):

  • It hooks into every Node creation here and stores the current stack trace (in a slightly different way depending on whether it is in c++ or python, see each implementation for details) as well as its parent (in case of re-entrant autograd: which Node triggered the autograd call that led to the creation of this new Node).
  • When an error happens during the execution of the engine, print the stack trace of the forward (and all its parents) here.
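
A small example of how this looks from the python side. MyBadFunc is just a hypothetical custom Function used here to force a backward failure; the forward must run inside the anomaly context so that the stack trace is recorded at Node creation:

```python
import torch

class MyBadFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        raise RuntimeError("something went wrong in the backward")

x = torch.randn(3, requires_grad=True)
with torch.autograd.detect_anomaly():
    y = MyBadFunc.apply(x)   # the forward stack trace is recorded here
    # When the backward below fails, autograd prints (as a warning) the
    # traceback of the forward call that created the failing Node, on top
    # of the usual error.
    y.sum().backward()
```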

That should give a good overview of where all the main pieces are. There are still some parts that are missing, like the role of the dispatcher or the new inference mode.
Please leave any comments or questions you have below!
