Some background
The autograd codegen (which lives in `tools/autograd/*`) is responsible for generating a lot of code, even though part of that scope is not related to autograd at all and lives there mostly for historical reasons.
Indeed, ATen used to be a pure Tensor library, while most things happening on the frontend side were happening with `Variable`s, the thin wrappers around `Tensor` that handled autograd. And so the `Variable` layer, implemented by the autograd codegen, handled autograd, the translation from `Variable` to `Tensor`, and the binding of the functions to Python.
Today, the codegen is responsible for:
- Generating the functions that get registered for the `AutogradFunctionality` and/or `AutogradOther` keys (regular autograd code).
- Generating the functions that get registered to the `ADInplaceOrView` key (handling of views and inplace in autograd).
- Generating the functions that get registered to the `Tracer` key (handling of jit tracing).
- Generating all the objects required for creating the autograd graph (Nodes).
- Generating the Python binding for the `Tag` enum.
- Generating the Python binding for all the functions from native_functions.yaml (in the top-level namespace as well as `nn`, `fft`, `sparse`, `special` and `linalg`).
- Generating the Python binding for all the methods in native_functions.yaml.
- Generating the Python bindings for all the autograd graph Nodes.
- Generating the Python bindings for the NamedTuples returned by all native functions that return multiple outputs.
- Generating the `torch::foo` version of all the factory functions (where autograd is handled). One could argue that these should go through the dispatcher the same way as other functions, and they most likely should.
The asserts
Given the above situation, the code generated by the autograd codegen used to be the main entry point for the C++ API, and it is still a kernel that is always hit today for “elementary aten functions” (those that are not `CompositeImplicitAutograd`).
Moreover, the autograd codegen has some restrictions on the functions that it can handle (a function must always be a view or always be inplace, never only sometimes).
These two facts make this kernel a good place to run sanity checks on the behavior of functions and to ensure that each implementation follows the specification from its schema.
In particular, these asserts check that:
- A function should never change an input’s TensorImpl (the intrusive_ptr shouldn’t be played with).
- A function should never change an input’s StorageImpl (unless it is one of the few functions explicitly meant to do so). Note that inplace and out= ops can modify the content and/or size of that storage, but should never change the object itself.
- An out-of-place function must return outputs whose TensorImpl use_count is at most 1 (a fresh Tensor is being returned).
- For view operations, Tensors that should be views do share the same Storage.
- For non-view, out-of-place operations, the output’s storage must have a use_count of 1.
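To make the list above concrete, here is a toy Python model of how such checks can be structured. It snapshots the identity of each input’s TensorImpl and StorageImpl before calling the kernel, then compares them afterwards. The classes `Storage`, `Impl` and `Tensor` and the helper `checked_call` are hypothetical stand-ins, not the generated C++ code. The use_count checks are approximated by “the output aliases no input”.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class Storage:
    """Hypothetical stand-in for a StorageImpl: owns a flat list of values."""
    data: list


@dataclass(eq=False)
class Impl:
    """Hypothetical stand-in for a TensorImpl: points at a Storage."""
    storage: Storage


@dataclass(eq=False)
class Tensor:
    """Hypothetical stand-in for a Tensor handle: holds an Impl."""
    impl: Impl


def check(cond, msg):
    # Raise instead of silently continuing when an invariant is broken.
    if not cond:
        raise RuntimeError(msg)


def checked_call(op, inputs, *, is_view=False, is_inplace=False):
    # Snapshot object identities (not contents) before running the kernel.
    impls = [t.impl for t in inputs]
    storages = [t.impl.storage for t in inputs]
    out = op(*inputs)
    for t, impl, storage in zip(inputs, impls, storages):
        check(t.impl is impl, "op replaced an input's TensorImpl")
        check(t.impl.storage is storage, "op replaced an input's StorageImpl")
    if is_view:
        check(out.impl.storage is inputs[0].impl.storage,
              "view op output does not share storage with its base")
    elif not is_inplace:
        # Approximation of the use_count checks: a fresh output aliases no input.
        check(all(out.impl is not i for i in impls),
              "output reuses an input TensorImpl")
        check(all(out.impl.storage is not s for s in storages),
              "out-of-place output shares an input's storage")
    return out


def alias(t):
    # Correct view: new Impl, same Storage object.
    return Tensor(Impl(t.impl.storage))


def copying_view(t):
    # Copies the data into a new Storage: fine out-of-place, wrong as a view.
    return Tensor(Impl(Storage(list(t.impl.storage.data))))


def add_one_(t):
    # Correct in-place op: mutates the contents of the existing Storage.
    t.impl.storage.data[:] = [x + 1 for x in t.impl.storage.data]
    return t


def add_one_bad_(t):
    # Wrong in-place op: swaps in a brand new Storage object.
    t.impl.storage = Storage([x + 1 for x in t.impl.storage.data])
    return t
```

With this model, `checked_call(alias, [x], is_view=True)` passes. `copying_view` used as a view is rejected, and so is `add_one_bad_`, which swaps the storage object instead of mutating its contents.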
Some limitations from native functions
There are limitations to the checks above. In particular, the view relationship is checked by making sure that the storage is properly shared between the two Tensors, so this check only runs when a given TensorImpl has a storage associated with it. Most notably, TensorImpls such as `SparseTensorImpl` do not have one.
Some of the `use_count()` checks on Tensors can also be problematic when the function returns the same Tensor multiple times, or when it is internally more complex and holds global state.
These cases are quite rare, and we have small allow lists of 15 functions that opt out of a subset of these checks for these reasons.
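The gating described above can be sketched as follows, assuming each check can be toggled independently. The check labels, the op names used with it, and the `checks_to_run` helper are all hypothetical. The real allow lists live in the autograd codegen and are not reproduced here.

```python
# Hypothetical labels for the individual checks.
IMPL_UNCHANGED = "input TensorImpl unchanged"
STORAGE_UNCHANGED = "input StorageImpl unchanged"
IMPL_USE_COUNT = "output TensorImpl use_count <= 1"
VIEW_SHARES_STORAGE = "view shares storage with base"
STORAGE_USE_COUNT = "output storage use_count == 1"

ALL_CHECKS = frozenset({IMPL_UNCHANGED, STORAGE_UNCHANGED, IMPL_USE_COUNT,
                        VIEW_SHARES_STORAGE, STORAGE_USE_COUNT})
# Checks that need a StorageImpl to inspect (unavailable for, e.g., sparse tensors).
NEEDS_STORAGE = frozenset({STORAGE_UNCHANGED, VIEW_SHARES_STORAGE, STORAGE_USE_COUNT})
USE_COUNT_CHECKS = frozenset({IMPL_USE_COUNT, STORAGE_USE_COUNT})


def checks_to_run(op_name, has_storage, use_count_allowlist):
    """Return the subset of checks that applies to one call of `op_name`."""
    checks = set(ALL_CHECKS)
    if not has_storage:
        # Storage-based checks are skipped when the TensorImpl has no storage.
        checks -= NEEDS_STORAGE
    if op_name in use_count_allowlist:
        # Allow-listed ops opt out of the use_count checks only.
        checks -= USE_COUNT_CHECKS
    return checks
```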
Interaction with TensorImpls that don’t store data directly
A good example of a TensorImpl that interacts well with these asserts is the `FunctionalTensorWrapper` TensorImpl. This Tensor is used as a wrapper around the “real” Tensor so that all the inplace/view ops can be transformed into out-of-place operations without changing the semantics of the program or the user code.
In particular, it uses the `storage_` field to track views and to know when Tensors within this view have been changed. This allows it to lazily apply mutations that happened on other views onto this one.
An important side effect of this is that all the `FunctionalTensorWrapper`s that are views of each other actually share the same Storage (as it is the common structure used to track views).
This means that all the asserts above work as expected and automatically ensure, for every op, that this TensorImpl does not track views incorrectly.
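A deliberately simplified model shows why shared storage makes these asserts work. Every view holds the same storage object, and a generation counter on that object tells each view when to refresh. `SharedStorage` and `ToyWrapper` are hypothetical. 1-D slices stand in for arbitrary views, and mutations are written straight into the base. That is not how the real functionalization machinery records and replays updates.

```python
class SharedStorage:
    """Toy stand-in for the storage shared by all views of one base."""

    def __init__(self, data):
        self.base = list(data)  # up-to-date contents of the base
        self.generation = 0     # bumped on every mutation through any view


class ToyWrapper:
    """Toy wrapper: a 1-D slice [start, stop) of a SharedStorage."""

    def __init__(self, storage, start, stop):
        self.storage = storage  # the same object for every view of this base
        self.start, self.stop = start, stop
        self._cache = storage.base[start:stop]
        self._generation = storage.generation

    def view(self, start, stop):
        # Views share the storage object; offsets are relative to this wrapper.
        return ToyWrapper(self.storage, self.start + start, self.start + stop)

    def fill_(self, value):
        # In-place op: write through to the base, then mark every view as stale.
        for i in range(self.start, self.stop):
            self.storage.base[i] = value
        self.storage.generation += 1
        return self

    def value(self):
        # Lazily refresh if any view mutated the shared storage since the last read.
        if self._generation != self.storage.generation:
            self._cache = self.storage.base[self.start:self.stop]
            self._generation = self.storage.generation
        return list(self._cache)
```

Because `a.storage is b.storage` holds for any two views of the same base, a storage-identity assert such as “views share the same Storage” holds for this wrapper automatically.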
Interaction with Tensor subclasses
Direct subclass (or “is-a” subclass)
As a reminder, we call a direct subclass a Tensor subclass created with `_make_subclass` for which `self` is an actual Tensor on which we run the backend implementation (usually by calling the `super()` implementation of `__torch_dispatch__`).
Such a Tensor has all its metadata properly set at all times, since that metadata is required when the Tensor’s code actually executes.
Such a subclass should interact just fine with the view asserts (or at least as well as the plain Tensor they subclass).
The only assert that can be problematic is the one that ensures that the returned Tensor’s use count is 1. If the user saves the returned Tensor in any global state on the Python side, this assert fails even though doing so is allowed (if discouraged in most cases).
Wrapper subclass (or “has-a” subclass)
As a reminder, we call wrapper subclasses the ones created with `_make_wrapper_subclass`. For these classes, `self` has all the properties of a Tensor except that it doesn’t hold any actual data. There is usually a field on the Python object of the subclass that stores the actual data required by the subclass (if any).
These subclasses are more problematic because:
- their metadata can get out of date, and as of now there is no direct way to set new values for it beyond running the metadata-mutating function on `self` via a `super()` call and assuming that this function doesn’t touch the data_ptr. We should be able to use Tags to mark such functions and thus allow the user to easily get this working.
- they have a proper storage with all the right properties, except that its raw pointer is always `nullptr`. This means that the user needs to properly propagate this Storage when doing view ops to avoid problems. This can be done easily via the `t.set_(storage)` method on Tensor.
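Here is a toy model of that propagation step. It assumes a wrapper tensor exposes its storage object and a `set_`-like method that adopts another tensor’s storage. `NullStorage`, `WrapperTensor` and `wrapper_unsqueeze0` are hypothetical. The real `Tensor.set_` also updates size and stride, and it offers an overload taking `storage_offset`, `size` and `stride`. This sketch only swaps the storage object.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)
class NullStorage:
    """Toy storage: correct size metadata, but the data pointer is always None."""
    nbytes: int
    data_ptr: Optional[int] = None


@dataclass(eq=False)
class WrapperTensor:
    """Toy wrapper subclass instance: outer metadata only, no real data."""
    sizes: tuple
    storage: NullStorage

    def set_(self, storage):
        # Toy analogue of Tensor.set_: adopt another tensor's storage object.
        self.storage = storage
        return self


def wrapper_unsqueeze0(t):
    # Build the outer metadata of the view (the inner computation is omitted).
    out = WrapperTensor((1,) + t.sizes, NullStorage(t.storage.nbytes))
    # Share the input's storage so a view assert sees input and output as aliases.
    return out.set_(t.storage)
```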
Interaction with Tensors on the “meta” device
These Tensors are a bit special, as their implementation has evolved quite a bit. But as of writing, “meta” is a full-fledged device with a custom allocator and everything, so Tensors on this device should behave the same as any other.
Unfortunately, the current meta implementations for most view kernels do NOT properly share storage between views (via set_ or otherwise) and thus they do fail the view asserts. We should be able to easily fix that by making sure to properly share storage between input/output Tensors for view operations.
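The difference between the current behavior and the proposed fix can be modeled in a few lines for a 2-D transpose. `MetaStorage`, `MetaTensor` and both kernels are hypothetical. Like a meta kernel, they track only metadata, but they are not the actual meta registrations.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class MetaStorage:
    """Toy meta storage: records a size in bytes, never holds data."""
    nbytes: int


@dataclass(eq=False)
class MetaTensor:
    """Toy meta tensor: storage plus sizes/strides/offset metadata."""
    storage: MetaStorage
    sizes: tuple
    strides: tuple
    offset: int = 0


def transpose_meta_buggy(t):
    # Right metadata, but a freshly allocated storage: the output
    # no longer looks like a view of its input.
    return MetaTensor(MetaStorage(t.storage.nbytes),
                      t.sizes[::-1], t.strides[::-1], t.offset)


def transpose_meta_fixed(t):
    # Same metadata, but the output reuses the input's storage object.
    return MetaTensor(t.storage, t.sizes[::-1], t.strides[::-1], t.offset)


def view_shares_storage(base, out):
    # The property the view assert checks.
    return out.storage is base.storage
```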
Proposed fixes
I think the following fixes need to happen:
- Fix the “meta” implementations to properly implement views (by sharing storage, since they have one today) and ensure that `test_meta.py` passes the asserts.
- Provide any necessary Python API to allow subclass writers to report the view behavior.
- Re-enable the asserts for subclasses and fix the tests with the workarounds proposed above (partially solves pytorch/pytorch issue #78519, “test_python_dispatch fails on DEBUG=1”):
  - all the direct subclasses should be relatively simple to fix
  - all the wrapper subclasses will most likely need extra logic to handle views properly. This is as simple as setting the proper storage on them by using `set_`. We can make that choice either by looking at the backing elements and checking whether they are views, or by using tags on the functions themselves.
- Turn these internal asserts into `TORCH_CHECK`s, as they can now be reached by end users of Tensor subclasses and should provide better error messages.
- Fix our CI to actually run DEBUG builds so that we do get signal from these asserts (pytorch/pytorch issue #78634, “Debug job does not build in debug mode”).
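The wrapper-subclass item above mentions two ways of choosing when to share storage, sketched here. Everything in this sketch is hypothetical, including the `is_view` tag name, which is not an existing `torch.Tag` member. Assigning `storage` directly stands in for calling `set_`.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class Inner:
    """Toy inner (backing) tensor held by a wrapper subclass."""
    storage: object


@dataclass(eq=False)
class Outer:
    """Toy wrapper subclass instance: outer storage plus the wrapped Inner."""
    storage: object
    elem: Inner


VIEW_TAG = "is_view"  # hypothetical tag name


def is_view_by_backing(inp, out):
    # Strategy 1: the output is a view if the backing elements share storage.
    return out.elem.storage is inp.elem.storage


def is_view_by_tag(op_tags):
    # Strategy 2: trust a tag attached to the operator.
    return VIEW_TAG in op_tags


def propagate_view(inp, out, op_tags=None):
    # Use the tags when they are provided, otherwise inspect the backing elements.
    if op_tags is not None:
        is_view = is_view_by_tag(op_tags)
    else:
        is_view = is_view_by_backing(inp, out)
    if is_view:
        out.storage = inp.storage  # toy stand-in for calling set_ on the output
    return out
```

Inspecting the backing elements needs no extra annotations but relies on the inner tensors having meaningful storage. Tags work even when the inner data is opaque, at the cost of annotating every view op.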