Interested in how PyTorch’s autograd works conceptually?
Want to understand how TorchScript can fuse operations even when they are recording gradients?
I put together an executable notebook, Simple Grad, that walks through a pedagogical implementation of autograd. It is conceptually very similar to the one PyTorch uses, but free of all the messy implementation details, like defining gradients/autograd in C++. Spending time understanding how this example works should make it easier to read PyTorch’s actual autograd code and make changes to it.
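To make the idea concrete, here is a minimal sketch of tape-based reverse mode in the same spirit. It is not the notebook's code, and every name in it (`Variable`, `TAPE`, `grad`, the op functions) is hypothetical: each operator computes its result and records a backward closure on a global tape, and `grad` replays the tape in reverse, accumulating gradients by name.

```python
import math

class Variable:
    """A scalar value with a unique name used to look up its gradient."""
    _next_id = 0

    def __init__(self, value):
        self.value = value
        self.name = f"v{Variable._next_id}"
        Variable._next_id += 1

# The tape: one entry per executed op, in execution order.
# Each entry is (input names, output name, backward closure: grad_out -> grads of inputs).
TAPE = []

def _record(inputs, out, backward):
    TAPE.append(([v.name for v in inputs], out.name, backward))
    return out

def add(a, b):
    out = Variable(a.value + b.value)
    return _record([a, b], out, lambda g: [g, g])

def mul(a, b):
    out = Variable(a.value * b.value)
    # The closure captures only the op's own inputs a and b.
    return _record([a, b], out, lambda g: [g * b.value, g * a.value])

def sin(a):
    out = Variable(math.sin(a.value))
    return _record([a], out, lambda g: [g * math.cos(a.value)])

def grad(output, inputs):
    """Reverse mode: walk the tape backwards, accumulating d(output)/d(name)."""
    grads = {output.name: 1.0}
    for input_names, out_name, backward in reversed(TAPE):
        g_out = grads.get(out_name)
        if g_out is None:
            continue  # this op does not feed into `output`
        for name, g_in in zip(input_names, backward(g_out)):
            grads[name] = grads.get(name, 0.0) + g_in
    return [grads.get(v.name, 0.0) for v in inputs]

# Usage: z = x*y + sin(x), so dz/dx = y + cos(x) and dz/dy = x.
x, y = Variable(2.0), Variable(3.0)
z = add(mul(x, y), sin(x))
dx, dy = grad(z, [x, y])
print(dx, dy)
```

Replaying the tape in reverse execution order guarantees that every consumer of a value has already contributed its gradient before that value's own closure runs.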
Cool!
So one thing I’ve been wondering about is whether we should make this type of experiment easier by allowing programmatic access to derivatives.
The “easy” part might be exposing what is in derivatives.yaml / what is generated from it.
The hard part is what to do with the functions in ATen that have “tape-based differentiation”.
This would also enable multiple implementations of autograd (autograd classic, autodiff, “hacker’s autograd” (i.e. in Python)) to use the same per-operator derivatives.
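As a rough illustration of shared per-operator derivatives, the sketch below keeps a hypothetical `DERIVATIVES` table (loosely modeled on derivatives.yaml, not its real format) and drives two toy engines from it: one reverse mode, one forward mode. It assumes scalar ops and straight-line programs in SSA form; for scalars, evaluating a backward formula with `grad = 1.0` yields the partial derivatives, which is what lets the forward-mode engine reuse the same table.

```python
import math

# Hypothetical derivative table, loosely modeled on derivatives.yaml:
# op name -> (forward function, backward formula (grad, *inputs) -> one grad per input).
DERIVATIVES = {
    "add": (lambda a, b: a + b, lambda grad, a, b: (grad, grad)),
    "mul": (lambda a, b: a * b, lambda grad, a, b: (grad * b, grad * a)),
    "sin": (math.sin, lambda grad, a: (grad * math.cos(a),)),
}

def reverse_mode(program, inputs, output):
    """Engine 1: evaluate the program, then apply backward formulas in reverse order."""
    env = dict(inputs)
    for out, op, args in program:
        env[out] = DERIVATIVES[op][0](*(env[a] for a in args))
    grads = {output: 1.0}
    for out, op, args in reversed(program):
        if out not in grads:
            continue  # this op does not contribute to `output`
        input_grads = DERIVATIVES[op][1](grads[out], *(env[a] for a in args))
        for a, g in zip(args, input_grads):
            grads[a] = grads.get(a, 0.0) + g
    return {name: grads.get(name, 0.0) for name in inputs}

def forward_mode(program, inputs, seed, output):
    """Engine 2: push a tangent forward. For scalar ops, calling a backward
    formula with grad=1.0 returns the partial derivatives, so the table is reused."""
    env = dict(inputs)
    tangent = {name: 1.0 if name == seed else 0.0 for name in inputs}
    for out, op, args in program:
        values = [env[a] for a in args]
        env[out] = DERIVATIVES[op][0](*values)
        partials = DERIVATIVES[op][1](1.0, *values)
        tangent[out] = sum(p * tangent[a] for p, a in zip(partials, args))
    return tangent[output]

# z = sin(x * y) + x as a straight-line program: (output name, op, argument names)
program = [("t1", "mul", ("x", "y")), ("t2", "sin", ("t1",)), ("z", "add", ("t2", "x"))]
inputs = {"x": 2.0, "y": 3.0}
print(reverse_mode(program, inputs, "z"))       # dz/dx and dz/dy from one backward pass
print(forward_mode(program, inputs, "x", "z"))  # dz/dx again, one forward pass per input
```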
@ezyang and I (among others) have occasionally discussed this, but haven’t settled on a plan of action. Roughly, we’ve wanted to rewrite the derivatives in derivatives.yaml using Python syntax (e.g. each formula is a Python function that returns its result and a closure for backward, similar to this notebook). Then we would parse the AST and transliterate the formulas to C++ for the ‘real’ autograd. Combined with a Python reference implementation of the autograd engine, everything would be much more hackable.
But this requires quite a few steps:
Translate derivatives.yaml to Python format (automated, but run once with the results checked in) -> transliterate the Python function ASTs into C++, replacing the functionality of derivatives.yaml (automated, run on every build). Then write a reference autograd.
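A toy version of the transliteration step, assuming a hypothetical convention where each formula is a Python function whose nested `backward` returns one gradient per forward input. It handles only names, numeric constants, `+ - * /`, unary minus and method calls, and emits C++ expression strings; it is a sketch of the idea, not PyTorch's codegen.

```python
import ast

# Hypothetical Python-syntax formulas: each op returns (result, backward closure),
# and backward returns one gradient per forward input, in order.
SOURCE = '''
def mul(self, other):
    result = self * other
    def backward(grad):
        return grad * other, grad * self
    return result, backward

def sin(self):
    result = self.sin()
    def backward(grad):
        return grad * self.cos()
    return result, backward
'''

CPP_BINOPS = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*", ast.Div: "/"}

def to_cpp(node):
    """Transliterate a small subset of Python expression nodes into C++ source text."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return repr(node.value)
    if isinstance(node, ast.BinOp) and type(node.op) in CPP_BINOPS:
        return f"({to_cpp(node.left)} {CPP_BINOPS[type(node.op)]} {to_cpp(node.right)})"
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return f"(-{to_cpp(node.operand)})"
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
        # Method calls such as self.cos() map to the same C++ method-call syntax.
        args = ", ".join(to_cpp(a) for a in node.args)
        return f"{to_cpp(node.func.value)}.{node.func.attr}({args})"
    raise NotImplementedError(f"unsupported syntax: {ast.dump(node)}")

def transliterate(source):
    """Return {op name: {input name: C++ gradient expression}}."""
    entries = {}
    for fn in ast.parse(source).body:
        inputs = [a.arg for a in fn.args.args]
        backward = next(s for s in fn.body if isinstance(s, ast.FunctionDef))
        returned = backward.body[-1].value
        grads = returned.elts if isinstance(returned, ast.Tuple) else [returned]
        entries[fn.name] = {name: to_cpp(g) for name, g in zip(inputs, grads)}
    return entries

for op, grads in transliterate(SOURCE).items():
    for name, expr in grads.items():
        print(f"auto {op}_grad_{name} = {expr};")
```

Because the source of truth is ordinary Python, the same functions could also be executed directly by a Python reference engine, while the C++ strings feed the compiled autograd.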
Note that while this is nice to read, it is dangerous when it comes to double backward.
In particular, it is very easy for the user to capture an intermediate result from the forward (one that is not an output) and use it in the backward. Doing so would make the double backward wrong.
The current derivatives.yaml prevents this by design, as only the inputs and outputs are available to the user to compute the backward!
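The hazard can be reproduced with a small scalar engine whose backward closures build graph nodes, so gradients can be differentiated again (all names here are hypothetical). A plain Python float stands in for an intermediate computed inside an opaque forward, and therefore without graph history: both versions of `x**3` give the right first derivative, but only the one whose backward is written in terms of the input gets the second derivative right.

```python
class Var:
    """Scalar graph node. `parents` holds (parent, vjp) pairs, where vjp maps the
    upstream gradient (a Var) to this node's gradient contribution for that parent."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = list(parents)

    def __add__(self, other):
        other = lift(other)
        return Var(self.value + other.value, [(self, lambda g: g), (other, lambda g: g)])

    def __mul__(self, other):
        other = lift(other)
        # The vjps are built from Var ops, so gradients stay differentiable.
        return Var(self.value * other.value,
                   [(self, lambda g: g * other), (other, lambda g: g * self)])

    __radd__ = __add__  # handles 1.0 + var
    __rmul__ = __mul__  # handles 3.0 * var

def lift(x):
    """Wrap plain numbers as constant leaf nodes."""
    return x if isinstance(x, Var) else Var(float(x))

def grad(output, wrt):
    """Reverse-mode d(output)/d(wrt); the result is a Var that can be differentiated again."""
    order, seen = [], set()

    def visit(v):
        if id(v) in seen:
            return
        seen.add(id(v))
        for parent, _ in v.parents:
            visit(parent)
        order.append(v)  # post-order: every node lands after its parents

    visit(output)
    grads = {id(output): Var(1.0)}
    for v in reversed(order):  # consumers before producers
        g = grads.get(id(v))
        if g is None:
            continue
        for parent, vjp in v.parents:
            contrib = vjp(g)
            key = id(parent)
            grads[key] = grads[key] + contrib if key in grads else contrib
    return grads.get(id(wrt), Var(0.0))

def cube_capturing_intermediate(x):
    t = x.value * x.value  # intermediate computed outside the graph
    # Backward reads the captured plain number t, so its result is not connected to x.
    return Var(t * x.value, [(x, lambda g: g * (3.0 * t))])

def cube_from_inputs(x):
    # Backward uses only the input x, rebuilding 3*x*x from graph ops.
    return Var(x.value ** 3, [(x, lambda g: g * (3.0 * (x * x)))])

x = Var(3.0)
for cube in (cube_capturing_intermediate, cube_from_inputs):
    dy_dx = grad(cube(x), x)
    d2y_dx2 = grad(dy_dx, x)
    print(cube.__name__, dy_dx.value, d2y_dx2.value)
# Both first derivatives are 3x^2 = 27.0; the second derivative should be 6x = 18.0,
# but the capturing version returns 0.0 because its gradient never reaches x.
```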
Right! I don’t think checking for this would be an insurmountable obstacle, though, so it might not be too bad. The same trap already exists in symbolic script.
I think the idea would be to inspect the Python AST to sort out captures and turn them into proper ‘SavedVariables’ where needed. This could be done both for the C++ transliteration and as an AST → AST transform applied as an annotation to the Python code itself. Obviously, it would need to be quite limited to keep it simple.
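A limited sketch of such an AST check, assuming a hypothetical convention where the formula's last statement is `return <outputs>, backward`: it reports names the backward closure reads that are neither forward inputs, forward outputs, its own parameters or locals, nor builtins. It only looks at bare names, so attribute access, module globals and deeper nesting would need more work.

```python
import ast
import builtins

def intermediate_captures(source):
    """Return names read inside the nested `backward` closure that are not forward
    inputs, forward outputs, backward parameters, backward locals, or builtins."""
    fn = ast.parse(source).body[0]
    inputs = {a.arg for a in fn.args.args}
    # Assumed convention: the last statement is `return <outputs...>, backward`.
    returned = fn.body[-1].value
    outputs = {e.id for e in returned.elts if isinstance(e, ast.Name)}
    backward = next(s for s in fn.body if isinstance(s, ast.FunctionDef))
    params = {a.arg for a in backward.args.args}
    names = [n for n in ast.walk(backward) if isinstance(n, ast.Name)]
    stored = {n.id for n in names if isinstance(n.ctx, ast.Store)}
    loaded = {n.id for n in names if isinstance(n.ctx, ast.Load)}
    allowed = inputs | outputs | params | stored | set(dir(builtins))
    return sorted(loaded - allowed)

UNSAFE = '''
def cube(x):
    t = x * x
    y = t * x
    def backward(grad_y):
        return grad_y * 3 * t
    return y, backward
'''

SAFE = '''
def cube(x):
    t = x * x
    y = t * x
    def backward(grad_y):
        return grad_y * 3 * x * x
    return y, backward
'''

print(intermediate_captures(UNSAFE))  # ['t']
print(intermediate_captures(SAFE))    # []
```

Detecting the capture is only the first step: a flagged name like `t` still has to be handled somehow (rejected, recomputed from the inputs, or turned into a saved value) before the backward is safe to differentiate.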
You need more than that! You need them to be outputs of the forward, because we need a way to compute the gradients for them!
The user needs to tell us what to do with the gradients flowing into these intermediates!
In the JIT autodiff, a lot of work goes into making this work: these intermediates are made outputs of the Node, and the extra outputs are then hidden before returning to the user.
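A toy analogue of that approach, built on a tiny scalar engine whose backward closures are themselves differentiable (all names are hypothetical; this is not the JIT's code). The intermediate `t = x*x` becomes a second output of the op with its own gradient formula (`2 * x`, the part the user has to supply), the backward for `y` reads `t` as a graph node, and a wrapper hides `t` from the caller. During double backward, the gradient that reaches `t` is routed to `x` through `t`'s formula.

```python
class Var:
    """Scalar graph node; `parents` holds (parent, vjp) pairs whose vjps build Var ops."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = list(parents)

    def __add__(self, other):
        other = lift(other)
        return Var(self.value + other.value, [(self, lambda g: g), (other, lambda g: g)])

    def __mul__(self, other):
        other = lift(other)
        return Var(self.value * other.value,
                   [(self, lambda g: g * other), (other, lambda g: g * self)])

    __radd__ = __add__
    __rmul__ = __mul__

def lift(x):
    return x if isinstance(x, Var) else Var(float(x))

def grad(output, wrt):
    """Reverse-mode d(output)/d(wrt), returned as a differentiable Var."""
    order, seen = [], set()

    def visit(v):
        if id(v) in seen:
            return
        seen.add(id(v))
        for parent, _ in v.parents:
            visit(parent)
        order.append(v)

    visit(output)
    grads = {id(output): Var(1.0)}
    for v in reversed(order):
        g = grads.get(id(v))
        if g is None:
            continue
        for parent, vjp in v.parents:
            contrib = vjp(g)
            key = id(parent)
            grads[key] = grads[key] + contrib if key in grads else contrib
    return grads.get(id(wrt), Var(0.0))

def cube_internal(x):
    """Forward with two outputs: y = x**3 and the intermediate t = x*x."""
    # t is an output with its own gradient formula: grad_t -> grad_t * 2x.
    t = Var(x.value * x.value, [(x, lambda g: g * (2.0 * x))])
    # y's backward reads t as a graph node instead of capturing a plain value.
    y = Var(t.value * x.value, [(x, lambda g: g * (3.0 * t))])
    return y, t

def cube(x):
    y, _hidden_t = cube_internal(x)  # hide the extra output from the caller
    return y

x = Var(3.0)
dy_dx = grad(cube(x), x)   # 3x^2 = 27.0
d2y_dx2 = grad(dy_dx, x)   # 6x = 18.0, via the gradient formula supplied for t
print(dy_dx.value, d2y_dx2.value)
```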