Abstract: functorch does not need to manually implement levels and internally dispatch to the appropriate functorch transformation for a level; the existing subclass-based dispatch in
__torch_dispatch__ is sufficient, so long as every distinct level of transformation is given a fresh, dynamically allocated class. The tree structure of a class hierarchy also prevents the level confusion that arises (when levels are expressed as integers) when objects with levels escape their lexical scope.
At first glance,
__torch_function__ (and
__torch_dispatch__, which is implemented using the same logic) is simply a mechanism for subclasses of Tensor to override the meaning of functions in the global namespace, where traditional dynamic dispatch on method invocation is insufficient. However,
__torch_function__ also comes with a mechanism for handling multiple dispatch situations when tensors of two distinct subclasses are passed to the same operation. The behavior is perhaps best described with a set of examples:
```python
class Parent:
    @staticmethod
    def __torch_function__(...): ...

class Child(Parent):
    @staticmethod
    def __torch_function__(...): ...

class Unrelated:
    @staticmethod
    def __torch_function__(...): ...
```
- torch.add(Parent(), Parent()) - this is the easy case; we simply invoke Parent.__torch_function__.
- torch.add(Parent(), Child()) - here, it is clear that Child is “aware of” Parent, as it is a subclass of Parent. So the correct function to invoke is Child.__torch_function__, with the assumption that Child knows how to handle Parent arguments. Algorithmically, this is implemented by preferring subclasses over parent classes.
- torch.add(Parent(), Unrelated()) - here, Parent and Unrelated have no subclass relationship, and there is no reason to presuppose that either knows how to deal with the other. So we pick an arbitrary function, e.g., Parent.__torch_function__, and if it returns NotImplemented (indicating that it doesn’t understand some of its arguments), we try the other function, e.g., Unrelated.__torch_function__. In fact, this process would also occur if a child class reported that it didn’t know how to handle its parent class, although this situation is unlikely.
To summarize, multiple dispatch is implemented by trying
__torch_function__ on each argument, one by one, until we find one that actually works for all the argument types in question. However, we always try subclass implementations first, since they “know of” their parent classes and are more specific.
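The resolution algorithm above can be sketched in pure Python. This is an illustrative model, not the real torch internals: `overload_resolve` and `handle` are stand-in names for the dispatcher and for `__torch_function__`, and the classes simply tag which handler won.

```python
def overload_resolve(func, *args):
    # Collect distinct argument types, ordering subclasses before
    # their superclasses so the more specific handler is tried first.
    types = []
    for a in args:
        t = type(a)
        if t in types:
            continue
        idx = next((i for i, u in enumerate(types) if issubclass(t, u)), len(types))
        types.insert(idx, t)
    # Try each candidate handler until one accepts the argument types.
    for t in types:
        result = t.handle(func, tuple(types), args)  # stand-in for __torch_function__
        if result is not NotImplemented:
            return result
    raise TypeError(f"no implementation found for {types}")

class Parent:
    @classmethod
    def handle(cls, func, types, args):
        # Bail out on types we don't understand.
        if not all(issubclass(t, Parent) for t in types):
            return NotImplemented
        return ("Parent", func(args))

class Child(Parent):
    @classmethod
    def handle(cls, func, types, args):
        # Child "knows of" Parent, so it accepts any Parent subclass.
        if not all(issubclass(t, Parent) for t in types):
            return NotImplemented
        return ("Child", func(args))

class Unrelated:
    @classmethod
    def handle(cls, func, types, args):
        # For the demo, Unrelated claims it can handle anything.
        return ("Unrelated", func(args))

assert overload_resolve(len, Parent(), Parent()) == ("Parent", 2)
assert overload_resolve(len, Parent(), Child()) == ("Child", 2)      # subclass wins
assert overload_resolve(len, Parent(), Unrelated()) == ("Unrelated", 2)  # fallback after NotImplemented
```

In the last case, Parent is tried first (it happens to come first in argument order), returns NotImplemented because Unrelated is a stranger, and dispatch moves on to Unrelated.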
Functorch transformations can be thought of as implemented by wrapping tensors with special wrappers, e.g., BatchTensor, which reinterpret the meaning of Tensor operations before forwarding on to their inner wrapped tensors. However, because functorch transformations can be composed with themselves, and because the order the transformations are composed matters, these wrappers must record the level associated with their computation. Take the following vmap example:
vmap(xs, lambda x: vmap(ys, lambda y: x + y))
Suppose that xs is a tensor of size (10, 2) and ys is a tensor of size (15, 2); the output size of this operation is, perhaps surprisingly, (10, 15, 2)! This can be understood by considering the types of x (a tensor of size 2) and y (also a tensor of size 2), and then the effect each vmap has on them. At the time we perform the addition, x + y has logical size (2). The inner vmap vectorizes this computation over ys’s batch dimension, giving size (15, 2). Then the outer vmap vectorizes the computation once again, over xs’s batch dimension, giving the final physical size (10, 15, 2).
In reality, x and y are batch tensors whose physical representations are (10, 2) and (15, 2). Without any other intervention, adding these two tensors would be a shape error. Levels help us understand that we should expand x to (10, 1, 2) and y to (1, 15, 2) prior to performing the (broadcasting) addition.
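The shape bookkeeping can be checked with a torch-free sketch that replays the two vmaps as explicit loops:

```python
# Pure-Python sketch (no torch): replay
#   vmap(xs, lambda x: vmap(ys, lambda y: x + y))
# as explicit loops over the two batch dimensions.
xs = [[1.0 * i, 2.0 * i] for i in range(10)]   # physical size (10, 2)
ys = [[3.0 * j, 4.0 * j] for j in range(15)]   # physical size (15, 2)

# The outer loop is the outer vmap (over xs), the inner loop is the inner
# vmap (over ys); the body adds two logical size-(2,) vectors, so the
# physical result has size (10, 15, 2).
out = [[[xk + yk for xk, yk in zip(x, y)] for y in ys] for x in xs]

assert (len(out), len(out[0]), len(out[0][0])) == (10, 15, 2)
```

The loop structure makes the expansion visible: each x is reused across every y (as if expanded to (10, 1, 2)) and vice versa (as if expanded to (1, 15, 2)).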
Integer levels can be problematic if they are allowed to escape from the lexical scope they are defined in:
```python
box = [None]

def write(x):
    box[0] = x
    return x

def read(y):
    return box[0] + y

xs = torch.randn(10, 2)
ys = torch.randn(20, 10, 2)
print(vmap(xs, write))
print(vmap(ys, read))
```
What should the semantics of this program be? The batch tensor x escapes from the first vmap and shows up in the second vmap. It is clearly unrelated to the batch tensor y (whose batch size is 20, compared to x’s 10). But in both cases x and y are simply recorded as BatchTensor with level 0, and vmap cannot distinguish the two distinct vmap calls and will attempt to add (10, 2) and (20, 10, 2) directly, resulting in an error.
Instead of recording integer levels, nested invocations of functorch transformations like vmap should instead generate fresh classes to wrap their arguments. Consider again the first vmap example:
vmap(xs, lambda x: vmap(ys, lambda y: x + y))
These two vmap invocations will result in the following dynamically generated class hierarchy (the names I have given are for purely expository reasons; in reality these classes are unnamed):
```python
class BatchTensor0:  # from outer vmap
    wrapped: Tensor
    @staticmethod
    def __torch_dispatch__(...): ...

class BatchTensor1(BatchTensor0):  # from inner vmap
    wrapped: Union[Tensor, BatchTensor0]
    @staticmethod
    def __torch_dispatch__(...): ...
```
x will be a BatchTensor0, while y will be a BatchTensor1. We can now apply the multiple dispatch mechanism from
__torch_function__ to determine which implementation of
__torch_dispatch__ is invoked in x + y: BatchTensor1 is a subclass of BatchTensor0, and thus we will invoke BatchTensor1’s implementation (the subclass).
The key ingredient in this encoding is the subclass relation between BatchTensor1 and BatchTensor0. Without it, multiple dispatch is free to pick whichever class it wants to invoke
__torch_dispatch__ on. With the subclass relation, it is obligated to handle BatchTensor1 first, which will do some expanding, unwrap its BatchTensor1 arguments, and then redispatch (heading to BatchTensor0). If no BatchTensor1s are involved in an expression, we short-circuit BatchTensor1 and immediately dispatch to BatchTensor0. Level corresponds to the depth of a class in the dynamically allocated class hierarchy.
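The fresh-class generation and the level-as-depth observation can be sketched in a few lines of pure Python. The names `fresh_batch_class` and `level` are illustrative, not functorch APIs:

```python
# Each vmap invocation dynamically allocates a fresh, unnamed-in-source class;
# nesting makes the new class a subclass of the enclosing vmap's class.
def fresh_batch_class(parent=object):
    return type("BatchTensor", (parent,), {})

BatchTensor0 = fresh_batch_class()              # outer vmap
BatchTensor1 = fresh_batch_class(BatchTensor0)  # inner (nested) vmap

# The subclass relation is what obligates multiple dispatch to try the
# innermost transformation first.
assert issubclass(BatchTensor1, BatchTensor0)

# "Level" is just the depth of the class in the generated hierarchy:
def level(cls):
    return len(cls.__mro__) - 2  # subtract `object` and the class itself

assert level(BatchTensor0) == 0
assert level(BatchTensor1) == 1
```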
We can also inspect what occurs in the lexical escape case:
```python
class BatchTensor0:  # from first vmap
    wrapped: Tensor
    @staticmethod
    def __torch_dispatch__(...): ...

class BatchTensor0_1:  # from second vmap
    wrapped: Tensor
    @staticmethod
    def __torch_dispatch__(...): ...
```
Clearly these classes are unrelated. We can detect this case simply by testing that each argument’s class is a superclass of the current class whose
__torch_dispatch__ is executing. As there is no subclass relation here, this will (rightly) cause an error.
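The check can be sketched directly; `check_no_level_confusion` is a hypothetical helper modeling the test described above:

```python
# When a class's dispatch handler runs, every other wrapper type involved
# should be an ancestor of it; unrelated generated classes fail this check.
def check_no_level_confusion(current_cls, arg_types):
    for t in arg_types:
        if not issubclass(current_cls, t):
            raise TypeError(f"{t.__name__} is unrelated to {current_cls.__name__}")

BatchTensor0 = type("BatchTensor0", (object,), {})        # first vmap
BatchTensor1 = type("BatchTensor1", (BatchTensor0,), {})  # nested vmap
BatchTensor0_1 = type("BatchTensor0_1", (object,), {})    # second, unrelated vmap

# Nested case: fine, BatchTensor0 is an ancestor of BatchTensor1.
check_no_level_confusion(BatchTensor1, (BatchTensor0, BatchTensor1))

# Escape case: an escaped BatchTensor0 shows up inside the second vmap.
try:
    check_no_level_confusion(BatchTensor0_1, (BatchTensor0,))
except TypeError:
    pass  # rightly rejected: the classes are unrelated
```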
The presence of these dynamically allocated classes also suggests a clear API for imperative-style vmap:
```python
with vmap() as BatchTensor0, vmap() as BatchTensor1:
    x = BatchTensor0(xs)
    y = BatchTensor1(ys)
    return x + y
```
There are two benefits to representing functorch levels in this way:
- It removes the necessity for a separate functorch dispatcher; indeed, there is no need for a functorch dispatcher at all after this change;
__torch_dispatch__ is all you need.
- It clearly explicates the semantics of non-lexical levels.
There are some performance questions regarding the cost of dynamically creating classes in Python, and whether or not this design is compatible with compressed wrappers (e.g., a single BatchTensor can support arbitrarily many vmap levels, so long as they are contiguously leveled). However, this seems to be a clear win in terms of semantics, and we should ensure forward-looking work obeys these semantics.
We need a sane API for
no_dispatch in order to implement functorch’s nested grad transform. Functorch’s grad transform is a hybrid between a mode dispatch key and a “Variable” wrapper: all newly created Tensors need to be wrapped in these Variables to prevent funny business from Variables at different levels interacting.
no_dispatch is necessary to prevent a mode key from infinitely recursing.
Concretely, no_dispatch should accept a tensor subclass and prevent dispatches for that tensor subclass.
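One possible shape for this API, sketched in pure Python (the names and the registry mechanism are assumptions, not an existing interface):

```python
from contextlib import contextmanager

# Set of tensor subclasses whose dispatch is currently suppressed.
_suppressed = set()

@contextmanager
def no_dispatch(cls):
    # Suppress dispatch for `cls` within the block, so a mode-style handler
    # can call back into regular behavior without recursing into itself.
    _suppressed.add(cls)
    try:
        yield
    finally:
        _suppressed.discard(cls)

def should_dispatch(x):
    # A dispatcher would consult this before invoking __torch_dispatch__.
    return type(x) not in _suppressed

class GradTensor:  # illustrative wrapper subclass
    pass

g = GradTensor()
assert should_dispatch(g)
with no_dispatch(GradTensor):
    assert not should_dispatch(g)  # inner operations skip GradTensor's handler
assert should_dispatch(g)          # suppression is scoped to the block
```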
Right now, most implementations of __torch_function__/__torch_dispatch__ don’t do anything with the passed types array. This is bad: they should be testing whether the types are what they expect (and, if they are not, returning NotImplemented). This is going to be… hard to enforce.
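A sketch of the discipline being asked for, with `MyTensor` as an illustrative stand-in (called directly here rather than through torch):

```python
class MyTensor:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Actually inspect `types`: refuse to handle strangers so that
        # multiple dispatch can move on to the next candidate.
        if not all(issubclass(t, MyTensor) for t in types):
            return NotImplemented
        return func(*args, **(kwargs or {}))

# A foreign type in `types` is declined...
assert MyTensor.__torch_function__(max, (MyTensor, int), (1, 2)) is NotImplemented
# ...while known types are handled normally.
assert MyTensor.__torch_function__(max, (MyTensor,), (1, 2)) == 2
```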
We’ve solved the problem of making levels lexical: the level, which is a class, is invalid outside of the context manager! Compare this to before, when levels were just integers and were in danger of being re-used.
Now, what happens if a tensor escapes the context manager? (What does the following return?)

```python
def foo(xs):
    with vmap() as BatchTensor0:
        x = BatchTensor0(xs)
    return x
```
Ideally, upon exiting the context manager,
x just becomes a normal tensor. Unfortunately, we can’t make this happen, because x is literally an instance of BatchTensor0 and we can’t change that in place (unless we pursue some weak-map idea).
There are a few possibilities here:
- x could become a “passthrough” BatchTensor0 whose
__torch_dispatch__ no longer calls batching rules and instead just invokes regular PyTorch operations. This can be done by directly mutating BatchTensor0, replacing its
__torch_dispatch__ with a new passthrough implementation.
- x could become a “dead” BatchTensor0 that just yells whenever it gets used. If this is the case, the user should really return
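The first, passthrough option can be sketched in pure Python; `dispatch` stands in for `__torch_dispatch__`, and the mutation-on-exit is the point being demonstrated:

```python
# Dynamically allocate a class whose dispatch hook applies "batching rules"
# (here just a tag); on context-manager exit, mutate the class in place so
# the same hook becomes a passthrough.
def make_batch_class():
    def batched_dispatch(self, func, *args):
        return ("batched", func(*args))
    return type("BatchTensor0", (object,), {"dispatch": batched_dispatch})

BatchTensor0 = make_batch_class()
x = BatchTensor0()
assert x.dispatch(lambda: 1) == ("batched", 1)

# Exiting the context manager would perform this swap:
def passthrough_dispatch(self, func, *args):
    return func(*args)

BatchTensor0.dispatch = passthrough_dispatch

# The *same instance* x now dispatches straight through, even though we
# never touched x itself.
assert x.dispatch(lambda: 1) == 1
```

Because method lookup goes through the class, mutating the class retroactively changes the behavior of every escaped instance, which is exactly what the in-place-instance approach could not achieve.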