Functorch: Levels as dynamically allocated classes
Abstract: functorch does not need to manually implement levels and internally dispatch to the appropriate functorch transformation for a level; the existing subclass dispatch in __torch_dispatch__ is sufficient, so long as every distinct level of transformation is given a fresh, dynamically allocated class. The tree structure of a class hierarchy also prevents the level confusion that integer levels suffer from when objects with levels escape their lexical scope.
Background: subclass dispatch in __torch_function__
At first glance, __torch_function__ (and __torch_dispatch__, which is directly implemented using the same logic) is simply a mechanism for subclasses of Tensor to override the meaning of functions in the global namespace, where traditional dynamic dispatch on method invocation is insufficient. However, __torch_function__ also comes with a mechanism for handling multiple dispatch situations, where tensors of two distinct subclasses are passed to the same operation. The behavior is perhaps best described with a set of examples:
class Parent:
    @staticmethod
    def __torch_function__(...): ...

class Child(Parent):
    @staticmethod
    def __torch_function__(...): ...

class Unrelated:
    @staticmethod
    def __torch_function__(...): ...
Consider:

- torch.add(Parent(), Parent()) - this is the easy case: we simply invoke Parent.__torch_function__.
- torch.add(Parent(), Child()) - here, it is clear that Child is "aware of" Parent, as it is a subclass of Parent. So the correct function to invoke is Child.__torch_function__, with the assumption that Child knows how to handle Parent arguments. Algorithmically, this is implemented by preferring subclasses over parent classes.
- torch.add(Parent(), Unrelated()) - here, Parent and Unrelated have no subclass relationship, and there is no reason to presuppose that either knows how to deal with the other. So we pick an arbitrary implementation, e.g., Parent.__torch_function__, and if it returns NotImplemented (indicating that it doesn't understand some of its arguments), we try the other, e.g., Unrelated.__torch_function__. This process would also occur if the child class reported that it didn't know how to handle the parent class, although that situation is less likely.
To summarize, multiple dispatch is implemented by trying __torch_function__ on each argument, one by one, until we find one that actually works for all the argument types in question. However, we will always try subclass implementations first, since they "know of" their parent classes and are more specific.
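To make the ordering rule concrete, here is a minimal sketch in plain Python (no torch) of how the list of candidate implementations could be ordered; overload_order and the toy classes are illustrative stand-ins, not PyTorch APIs:

def overload_order(arg_types):
    ordered = []
    for t in arg_types:
        if t in ordered:
            continue
        # Insert t just before the first superclass of t already in the list,
        # so subclass implementations are always tried first.
        index = len(ordered)
        for i, u in enumerate(ordered):
            if issubclass(t, u):
                index = i
                break
        ordered.insert(index, t)
    return ordered

class Parent: ...
class Child(Parent): ...
class Unrelated: ...

assert overload_order([Parent, Child]) == [Child, Parent]
assert overload_order([Parent, Unrelated]) == [Parent, Unrelated]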
Background: level dispatch in functorch
Functorch transformations can be thought of as implemented by wrapping tensors with special wrappers, e.g., BatchTensor, which reinterpret the meaning of Tensor operations before forwarding on to their inner wrapped tensors. However, because functorch transformations can be composed with themselves, and because the order in which the transformations are composed matters, these wrappers must record the level associated with their computation. Take the following vmap example:
vmap(xs, lambda x: vmap(ys, lambda y: x + y))
Suppose that xs is a tensor of size (10, 2) and ys is a tensor of size (15, 2); the output size of this operation is, perhaps surprisingly, (10, 15, 2)! This can be understood by considering the types of x (a tensor of size 2) and y (also a tensor of size 2), and then the effect each vmap has on them. At the time we perform the addition, x + y is of logical size (2). The inner vmap vectorizes this into a computation over ys's batch dimension, giving size (15, 2). Then the outer vmap vectorizes the computation once again over xs's batch dimension, giving the final physical size (10, 15, 2).
In reality, x and y are batch tensors whose physical representations are (10, 2) and (15, 2). Without any other intervention, adding these two tensors would be a shape error. Levels help us understand that we should expand x to (10, 1, 2) and y to (1, 15, 2) prior to performing the (broadcasting) addition.
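As a quick sketch in plain PyTorch (no functorch), this is the expansion the levels make possible; the shapes match the example above:

import torch

xs = torch.randn(10, 2)   # physical representation of x (outer batch dim 10)
ys = torch.randn(15, 2)   # physical representation of y (inner batch dim 15)

# Adding the physical tensors directly would be a shape error: xs + ys fails
# because 10 and 15 do not broadcast. Knowing which batch dimension belongs to
# which level lets us line the tensors up first:
out = xs.unsqueeze(1) + ys.unsqueeze(0)   # (10, 1, 2) + (1, 15, 2)
print(out.shape)                          # torch.Size([10, 15, 2])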
Integer levels can be problematic if they are allowed to escape from the lexical scope they are defined in:
box = [None]

def write(x):
    box[0] = x
    return x

def read(y):
    return box[0] + y

xs = torch.randn(10, 2)
ys = torch.randn(20, 10, 2)
print(vmap(xs, write))
print(vmap(ys, read))
What should the semantics of this program be? The batch tensor x escapes from the first vmap and shows up in the second vmap. It is clearly unrelated to the batch tensor y (whose batch size is 20, compared to x’s 10). But in both cases x and y are simply recorded as BatchTensor with level 0, and vmap cannot distinguish the two distinct vmap calls and will attempt to add (10, 2) and (20, 10, 2) directly, resulting in an error.
Dynamically allocated classes
Instead of recording integer levels, nested invocations of functorch transformations like vmap should instead generate fresh classes to wrap their arguments. Consider again the first vmap example:
vmap(xs, lambda x: vmap(ys, lambda y: x + y))
These two vmap invocations will result in the following dynamically generated class hierarchy (the names I have given are purely for expository purposes; in reality these classes are unnamed):
class BatchTensor0:  # from outer vmap
    wrapped: Tensor

    @staticmethod
    def __torch_dispatch__(...): ...

class BatchTensor1(BatchTensor0):  # from inner vmap
    wrapped: Union[Tensor, BatchTensor0]

    @staticmethod
    def __torch_dispatch__(...): ...
x will be a BatchTensor0, while y will be a BatchTensor1. We can now apply the multiple dispatch mechanism from __torch_function__ to determine which implementation of __torch_dispatch__ will be invoked in x + y. y's class is a subclass of x's class, and thus we will invoke BatchTensor1.__torch_dispatch__ (the subclass implementation).
The key ingredient to this encoding is the subclass relation between BatchTensor1 and BatchTensor0. Without it, the multiple dispatch is free to pick whichever class it wants to invoke __torch_dispatch__ on. With the subclass relation, it is obligated to handle BatchTensor1 first, which will do some expanding, unwrap its BatchTensor1 arguments, and then redispatch (heading to BatchTensor0). If no BatchTensor1s are involved in an expression, we will short-circuit BatchTensor1 and immediately dispatch to BatchTensor0. The level simply corresponds to the depth of a class in the dynamically allocated class hierarchy.
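As a rough illustration of what "dynamically allocated" means here, the following plain-Python sketch generates a fresh class per nested transformation; make_level_class and level_stack are hypothetical names, not functorch APIs, and the real classes would of course also carry a __torch_dispatch__ implementation:

level_stack = [object]  # the innermost active level's class sits at the end

def make_level_class():
    parent = level_stack[-1]
    # type() allocates a brand new class object on every call; nesting
    # transformations therefore yields a chain of subclasses whose depth in the
    # hierarchy plays the role the integer level used to play.
    cls = type("BatchTensor", (parent,), {})
    level_stack.append(cls)
    return cls

BatchTensor0 = make_level_class()  # outer vmap
BatchTensor1 = make_level_class()  # inner vmap
assert issubclass(BatchTensor1, BatchTensor0)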
We can also inspect what occurs in the lexical escape case:
class BatchTensor0:  # from first vmap
    wrapped: Tensor

    @staticmethod
    def __torch_dispatch__(...): ...

class BatchTensor0_1:  # from second vmap
    wrapped: Tensor

    @staticmethod
    def __torch_dispatch__(...): ...
Clearly these classes are unrelated. We can detect this case simply by testing that each argument's class is a superclass of the class whose __torch_dispatch__ is currently executing. As there is no subclass relation here, this will (rightly) cause an error.
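A sketch of what that test could look like inside a __torch_dispatch__ implementation (written as a classmethod so the executing class is available as cls; the structure follows the protocol, but the class and its body are illustrative):

import torch

class BatchTensor0(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Every overloaded type must be this class or one of its superclasses;
        # an unrelated, dynamically allocated level fails this test.
        if not all(issubclass(cls, t) for t in types):
            # Declining per the protocol; if every implementation declines,
            # the operation (rightly) errors out.
            return NotImplemented
        ...  # batching rule goes here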
The presence of these dynamically allocated classes also suggests a clear API for imperative-style vmap:
with vmap() as BatchTensor0, vmap() as BatchTensor1:
    x = BatchTensor0(xs)
    y = BatchTensor1(ys)
    return x + y
What’s next?
There are two benefits to representing functorch levels in this way:
- It removes the necessity for a separate functorch dispatcher; indeed, there is no need for a functorch dispatcher at all after this change: __torch_dispatch__ is all you need.
- It clearly explicates the semantics of non-lexical levels.
There are some performance questions regarding the cost of dynamically creating classes in Python, and whether or not this design is compatible with compressed wrappers (e.g., a single BatchTensor wrapper can support arbitrarily many vmap levels, so long as they are contiguously leveled). However, this seems to be a clear win in terms of semantics, and we should ensure forward-looking work obeys these semantics.
Appendix: Python mode, no_dispatch
We need a sane API for no_dispatch in order to implement functorch's nested grad transform. Functorch's grad transform is a hybrid between a mode-dispatch key and a "Variable" wrapper: all newly created Tensors need to be wrapped in these Variables to prevent funny business from Variables at different levels interacting. no_dispatch is necessary to prevent a mode key from recursing infinitely.
Concretely, no_dispatch should accept a tensor subclass and prevent dispatches for that tensor subclass.
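A hedged sketch of what such an API might look like; no_dispatch here is a hypothetical context manager and the _DISABLED set is only illustrative bookkeeping (the actual suppression would have to live in the dispatch machinery itself):

import contextlib

_DISABLED = set()  # tensor subclasses whose __torch_dispatch__ is suppressed

@contextlib.contextmanager
def no_dispatch(subclass):
    _DISABLED.add(subclass)
    try:
        yield
    finally:
        _DISABLED.discard(subclass)

# Usage inside a mode-style __torch_dispatch__: re-invoke func under
# no_dispatch(cls) so the nested call does not re-enter the same
# __torch_dispatch__ and recurse forever.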
Appendix: User education on type testing
Right now, most implementations of torch function/dispatch don’t do anything with the passed types array. This is bad: they should be testing if the types are what they expect (and if they are not, returning NotImplemented). This is going to be… hard to enforce.
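For example, a well-behaved __torch_function__ would do something along these lines (MyTensor and HANDLED_TYPES are illustrative names, but the overall pattern mirrors the documented subclassing recipe):

import torch

class MyTensor(torch.Tensor):
    HANDLED_TYPES = (torch.Tensor,)  # types this subclass knows how to handle

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Decline, rather than silently mishandle, anything we don't recognize,
        # so the other class's implementation gets a chance to run.
        if not all(issubclass(t, cls.HANDLED_TYPES) for t in types):
            return NotImplemented
        return super().__torch_function__(func, types, args, kwargs or {})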
Appendix: What happens if a tensor escapes the context manager?
We've solved the problem of keeping levels lexical: the level, which is a class, is invalid outside of the context manager! (Compare this to before, when levels were just integers and were in danger of being reused.)
Now, what happens if a tensor escapes the context manager? What does the following return?
def foo(xs):
    with vmap() as BatchTensor0:
        x = BatchTensor0(xs)
    return x
Ideally, upon exiting the context manager, x just becomes a normal tensor. Unfortunately, we can't make this happen, because x is literally an instance of BatchTensor0 and we can't change that in place (unless we pursue some weak map idea).
There are a few possibilities here:

- x could become a "passthrough" BatchTensor0 whose __torch_dispatch__ no longer calls batching rules and instead just invokes regular PyTorch operations. This can be done by directly mutating BatchTensor0, replacing its __torch_dispatch__ with a new passthrough one (see the sketch below).
- x could become a "dead" BatchTensor0 that just yells whenever it gets used. If this is the case, the user should really return x.elem() instead.
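A sketch of the passthrough option, assuming BatchTensor0 is a dynamically allocated class as above; unwrap() is a hypothetical helper that extracts the wrapped tensor from any BatchTensor0 arguments:

def _passthrough_torch_dispatch(cls, func, types, args=(), kwargs=None):
    # No batching rules anymore: unwrap any BatchTensor0 arguments and run the
    # regular PyTorch operation on the underlying tensors.
    unwrapped_args = tuple(unwrap(a) for a in args)
    unwrapped_kwargs = {k: unwrap(v) for k, v in (kwargs or {}).items()}
    return func(*unwrapped_args, **unwrapped_kwargs)

# On exiting the context manager, mutate the dynamically allocated class so any
# escaped instances turn into passthrough wrappers.
BatchTensor0.__torch_dispatch__ = classmethod(_passthrough_torch_dispatch)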