Functorch: Levels as dynamically allocated classes

Abstract: functorch does not need to manually implement levels and internally dispatch to the appropriate functorch transformation for a level; the existing subclass dispatch in __torch_dispatch__ is sufficient, so long as every distinct level of transformation is given a fresh, dynamically allocated class. The tree structure of a class hierarchy also prevents the level confusion that arises (when levels are expressed as integers) when objects with levels escape their lexical scope.

Background: subclass dispatch in __torch_function__

At first glance, __torch_function__ (and __torch_dispatch__, which is directly implemented using the same logic) is simply a mechanism for subclasses of Tensor to override the meaning of functions in the global namespace, where traditional dynamic dispatch on method invocation is insufficient. However, __torch_function__ also comes with a mechanism for handling multiple dispatch situations when tensors of two distinct subclasses are passed to the same operation. The behavior is perhaps best described with a set of examples:

class Parent:
    @staticmethod
    def __torch_function__(...): ...

class Child(Parent):
    @staticmethod
    def __torch_function__(...): ...
    
class Unrelated:
    @staticmethod
    def __torch_function__(...): ...

Consider:

  • torch.add(Parent(), Parent()) - this is the easy case, we simply invoke Parent.__torch_function__
  • torch.add(Parent(), Child()) - here, it is clear that Child is “aware of” Parent, as it is a subclass of Parent. So the correct function to invoke is Child.__torch_function__, with the assumption that Child knows how to handle Parent arguments. Algorithmically, this is implemented by preferring subclasses over parent classes.
  • torch.add(Parent(), Unrelated()) - here, Parent and Unrelated have no subclass relationship, and there is no reason to presuppose that either knows how to deal with the other. So we pick an arbitrary function, e.g., Parent.__torch_function__, and if it returns NotImplemented (indicating that it doesn’t understand some of its arguments), we try the other function, e.g., Unrelated.__torch_function__. In fact, this process would also occur if the child class reported that it didn’t know how to handle the parent class, although this situation is less likely.

To summarize, multiple dispatch is implemented by trying __torch_function__ on each argument, one by one, until we find one that actually works for all the argument types in question. However, we’ll always try subclass implementations first, since they “know of” their parent classes and are more specific.
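The resolution order described above can be sketched in pure Python. This is not PyTorch's actual implementation; `torch_function`, `overload_order`, and `dispatch` are invented names for exposition only.

```python
class Parent:
    @classmethod
    def torch_function(cls, func, types, args):
        # Parent only understands itself and its subclasses.
        if not all(issubclass(t, Parent) for t in types):
            return NotImplemented
        return f"Parent handled {func}"

class Child(Parent):
    @classmethod
    def torch_function(cls, func, types, args):
        # Child "knows of" Parent, so it accepts Parent arguments too.
        return f"Child handled {func}"

class Unrelated:
    @classmethod
    def torch_function(cls, func, types, args):
        return f"Unrelated handled {func}"

def overload_order(args):
    """Subclasses are tried before their parent classes; otherwise,
    classes are tried in argument order."""
    order = []
    for a in args:
        t = type(a)
        if t in order:
            continue
        pos = next((i for i, o in enumerate(order) if issubclass(t, o)),
                   len(order))
        order.insert(pos, t)
    return order

def dispatch(func, args):
    """Try each candidate's torch_function until one does not return
    NotImplemented."""
    types = tuple(type(a) for a in args)
    for t in overload_order(args):
        result = t.torch_function(func, types, args)
        if result is not NotImplemented:
            return result
    raise TypeError(f"no implementation for {func}")
```

In this sketch, dispatch("add", [Parent(), Child()]) picks Child's implementation because Child sorts before Parent, while dispatch("add", [Parent(), Unrelated()]) falls through Parent's NotImplemented and lands on Unrelated.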

Background: level dispatch in functorch

Functorch transformations can be thought of as implemented by wrapping tensors with special wrappers, e.g., BatchTensor, which reinterpret the meaning of Tensor operations before forwarding on to their inner wrapped tensors. However, because functorch transformations can be composed with themselves, and because the order the transformations are composed matters, these wrappers must record the level associated with their computation. Take the following vmap example:

vmap(xs, lambda x: vmap(ys, lambda y: x + y))

Suppose that xs is a tensor of size (10, 2) and ys is a tensor of size (15, 2); the output size of this operation is, perhaps surprisingly, (10, 15, 2)! This can be understood by considering the types of x (a tensor of size 2) and y (also a tensor of size 2), and then the effect each vmap has on them. At the time we perform the addition, x + y has logical size (2). The inner vmap vectorizes this into a computation over ys’s batch dimension, giving size (15, 2). Then the outer vmap vectorizes the computation once again over xs’s batch dimension, giving the final physical size (10, 15, 2).

In reality, x and y are batch tensors whose physical representations are (10, 2) and (15, 2). Without any other intervention, adding these two tensors would be a shape error. Levels help us understand that we should expand x to (10, 1, 2) and y to (1, 15, 2) prior to performing the (broadcasting) addition.
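The shape bookkeeping can be checked with a small pure-Python sketch (no torch involved; `expand_shape` and `broadcast` are invented helpers): each level inserts its batch dimension at its own position, with size-1 dimensions for every other active level, and then ordinary broadcasting does the rest.

```python
def expand_shape(logical, batch_size, level, num_levels):
    """Insert this tensor's batch dim at position `level`, and size-1
    dims for every other active level."""
    dims = [1] * num_levels
    dims[level] = batch_size
    return tuple(dims) + tuple(logical)

def broadcast(a, b):
    """Standard right-aligned broadcasting of two shapes."""
    ra = a[::-1] + (1,) * (len(b) - len(a))
    rb = b[::-1] + (1,) * (len(a) - len(b))
    out = []
    for da, db in zip(ra, rb):
        if da != 1 and db != 1 and da != db:
            raise ValueError("shape mismatch")
        out.append(max(da, db))
    return tuple(out[::-1])

# x: logical size (2,), batched at level 0 with batch size 10 -> (10, 1, 2)
x_shape = expand_shape((2,), 10, level=0, num_levels=2)
# y: logical size (2,), batched at level 1 with batch size 15 -> (1, 15, 2)
y_shape = expand_shape((2,), 15, level=1, num_levels=2)
```

Broadcasting (10, 1, 2) against (1, 15, 2) yields exactly the (10, 15, 2) physical result described above.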

Integer levels can be problematic if they are allowed to escape from the lexical scope they are defined in:

box = [None]

def write(x):
    box[0] = x
    return x

def read(y):
    return box[0] + y

xs = torch.randn(10, 2)
ys = torch.randn(20, 10, 2)
print(vmap(xs, write))
print(vmap(ys, read))

What should the semantics of this program be? The batch tensor x escapes from the first vmap and shows up in the second vmap. It is clearly unrelated to the batch tensor y (whose batch size is 20, compared to x’s 10). But in both cases x and y are simply recorded as BatchTensors with level 0, so vmap cannot distinguish the two distinct vmap calls: it will attempt to add (10, 2) and (20, 10, 2) directly, resulting in an error.

Dynamically allocated classes

Instead of recording integer levels, nested invocations of functorch transformations like vmap should instead generate fresh classes to wrap their arguments. Consider again the first vmap example:

vmap(xs, lambda x: vmap(ys, lambda y: x + y))

These two vmap invocations will result in the following dynamically generated class hierarchy (the names I have given are for purely expository reasons; in reality these classes are unnamed):

class BatchTensor0:  # from outer vmap
    wrapped: Tensor
    @staticmethod
    def __torch_dispatch__(...): ...

class BatchTensor1(BatchTensor0):  # from inner vmap
    wrapped: Union[Tensor, BatchTensor0]
    @staticmethod
    def __torch_dispatch__(...): ...

x will be a BatchTensor0, while y will be a BatchTensor1. We can now apply the multiple dispatch mechanism from __torch_function__ to determine which implementation of __torch_dispatch__ is invoked in x + y. y’s class is a subclass of x’s class, and thus we will invoke BatchTensor1.__torch_dispatch__ (the subclass’s implementation).

The key ingredient to this encoding is the subclass relation between BatchTensor1 and BatchTensor0. Without it, the multiple dispatch is free to pick whichever class it wants to invoke __torch_dispatch__ on. With the subclass relation, it is obligated to handle BatchTensor1 first, which will do some expanding, unwrap its BatchTensor1 arguments and then redispatch (heading to BatchTensor0). If no BatchTensor1s are involved in an expression, we will short-circuit BatchTensor1 and immediately dispatch to BatchTensor0. The level corresponds to the depth of a class in the dynamically allocated class hierarchy.
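Allocating these fresh classes is straightforward with Python's three-argument type(). The following sketch (with invented helper names; real BatchTensors would of course wrap actual tensors and define __torch_dispatch__) shows how each vmap invocation mints a new class parented to the enclosing one, and how level falls out as depth in the hierarchy:

```python
class BatchTensorBase:
    """Stand-in root for the wrapper hierarchy."""
    def __init__(self, wrapped):
        self.wrapped = wrapped

def make_batch_tensor_class(parent=BatchTensorBase):
    """Each vmap invocation allocates a brand-new class, parented to the
    class of the enclosing transformation (or the base if outermost)."""
    return type("BatchTensor", (parent,), {})

def level(cls):
    """Level == depth of the class in the dynamically allocated hierarchy."""
    depth = 0
    while cls is not BatchTensorBase:
        cls = cls.__bases__[0]
        depth += 1
    return depth

BatchTensor0 = make_batch_tensor_class()              # outer vmap
BatchTensor1 = make_batch_tensor_class(BatchTensor0)  # inner vmap
```

Because BatchTensor1 subclasses BatchTensor0, the subclass-first rule from __torch_function__ resolution automatically tries the inner level first.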

We can also inspect what occurs in the lexical escape case:

class BatchTensor0:  # from first vmap
    wrapped: Tensor
    @staticmethod
    def __torch_dispatch__(...): ...

class BatchTensor0_1:  # from second vmap
    wrapped: Tensor
    @staticmethod
    def __torch_dispatch__(...): ...

Clearly these classes are unrelated. We can detect this case simply by testing that each argument’s class is a superclass of the current class whose __torch_dispatch__ is executing. As there’s no subclass relation here, this will (rightly) cause an error.
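The escape check can be sketched as follows (`check_related` is an invented name; since subclasses dispatch first, every other participating type should be a superclass of the dispatching class):

```python
class Tensor:  # stand-in for torch.Tensor
    pass

def check_related(dispatching_cls, arg_types):
    """Every argument type must be a superclass of the class currently
    executing __torch_dispatch__; otherwise a wrapper escaped."""
    for t in arg_types:
        if not issubclass(dispatching_cls, t):
            raise TypeError(
                f"{t.__name__} is unrelated to {dispatching_cls.__name__}; "
                "did a wrapped tensor escape its transformation?")

# Nested vmaps produce related classes; repeated top-level vmaps do not.
BatchTensor0 = type("BatchTensor0", (Tensor,), {})        # from first vmap
BatchTensor1 = type("BatchTensor1", (BatchTensor0,), {})  # nested vmap
BatchTensor0_1 = type("BatchTensor0_1", (Tensor,), {})    # from second vmap
```

Mixing BatchTensor0 and BatchTensor0_1 in one call trips the check, which is exactly the error we want for the box example above.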

The presence of these dynamically allocated classes also suggests a clear API for imperative-style vmap:

with vmap() as BatchTensor0, vmap() as BatchTensor1:
    x = BatchTensor0(xs)
    y = BatchTensor1(ys)
    return x + y
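A hypothetical sketch of that imperative API, using contextlib (all names invented for exposition; a real implementation would wrap actual tensors and install __torch_dispatch__): each `with vmap()` allocates a fresh wrapper class, parented to the class of the enclosing level.

```python
from contextlib import contextmanager

class BatchTensorBase:
    def __init__(self, wrapped):
        self.wrapped = wrapped

_levels = [BatchTensorBase]  # stack of currently active levels

@contextmanager
def vmap():
    # Mint a fresh class for this level, parented to the enclosing one.
    fresh = type("BatchTensor", (_levels[-1],), {})
    _levels.append(fresh)
    try:
        yield fresh
    finally:
        _levels.pop()
```

In `with vmap() as BatchTensor0, vmap() as BatchTensor1:` the managers are entered left to right, so BatchTensor1 is automatically a subclass of BatchTensor0, and a later, unrelated `with vmap()` produces a class with no relation to either.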

What’s next?

There are two benefits to representing functorch levels in this way:

  • It removes the need for a separate functorch dispatcher; indeed, there is no need for a functorch dispatcher at all after this change: __torch_dispatch__ is all you need
  • It clearly explicates the semantics of non-lexical levels

There are some performance questions regarding the cost of dynamically creating classes in Python, and whether or not this design is compatible with compressed wrappers (e.g., a single BatchTensor can support arbitrarily many vmap levels, so long as they are contiguously leveled). However, this seems to be a clear win in terms of semantics, and we should ensure forward-looking work obeys these semantics.

Appendix: Python mode, no_dispatch

We need a sane API for no_dispatch in order to implement functorch’s nested grad transform. Functorch’s grad transform is a hybrid between a mode dispatch key and a “Variable” wrapper: all newly created Tensors need to be wrapped in these Variables to prevent funny business from Variables at different levels interacting. no_dispatch is necessary to prevent a mode key from recursing infinitely.

Concretely, no_dispatch should accept a tensor subclass and prevent dispatches for that tensor subclass.

Appendix: User education on type testing

Right now, most implementations of torch function/dispatch don’t do anything with the passed types array. This is bad: they should be testing if the types are what they expect (and if they are not, returning NotImplemented). This is going to be… hard to enforce.

Appendix: What happens if a tensor escapes the context manager?

We’ve solved the problem of making levels lexical: the level, which is a class, is invalid outside of the context manager! (Compare this to before, when levels were just integers and were in danger of being re-used.)

Now, what happens if a tensor escapes the context manager? (What does the following return?)

def foo(xs):
    with vmap() as BatchTensor0:
        x = BatchTensor0(xs)
        return x

Ideally, upon exiting the context manager, x would just become a normal tensor. Unfortunately, we can’t make this happen, because x is literally an instance of BatchTensor0 and we can’t change that in-place (unless we pursue some weak map idea).

There are a few possibilities here:

  1. x could become a “passthrough” BatchTensor0 whose __torch_dispatch__ no longer calls batching rules and instead just invokes regular PyTorch operations. This can be done by directly mutating BatchTensor0, replacing its torch dispatch with a new passthrough one.
  2. x could become a “dead” BatchTensor0 that just yells whenever it gets used. If this is the case, the user should really return x.elem().
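Either possibility works by mutating the dynamically allocated class on context-manager exit. A minimal sketch of option 2, the “dead” class (all names invented for exposition; a real BatchTensor would route real operators through __torch_dispatch__):

```python
class BatchTensor:
    def __init__(self, wrapped):
        self.wrapped = wrapped
    def elem(self):
        return self.wrapped
    def __add__(self, other):
        # Route through the class-level dispatch hook, mirroring how
        # __torch_dispatch__ intercepts operators.
        return type(self).dispatch("add", self, other)
    @classmethod
    def dispatch(cls, op, self, other):
        return ...  # a real batching rule would run here

def _dead_dispatch(cls, op, self, other):
    raise RuntimeError(
        "tensor used outside its vmap context; call .elem() instead")

def kill(cls):
    """Called when the vmap context exits: replace the class's dispatch
    with one that always errors."""
    cls.dispatch = classmethod(_dead_dispatch)
```

Because every level has its own freshly allocated class, mutating it affects exactly the tensors that escaped that one context manager and nothing else.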

Follow-up: Level ordering should NOT be done via subclassing

Previously, we proposed that subclassing could be used to mediate dispatch ordering between levels, leading to class hierarchies that look like this:

class BatchTensor0:  # from outer vmap
    wrapped: Tensor
    @staticmethod
    def __torch_dispatch__(...): ...

class BatchTensor1(BatchTensor0):  # from inner vmap
    wrapped: Union[Tensor, BatchTensor0]
    @staticmethod
    def __torch_dispatch__(...): ...

However, having BatchTensor1 inherit from BatchTensor0 is actually quite questionable. Remember that subclassing does multiple things:

  1. It affects the order in which we dispatch to classes (good!)
  2. You inherit fields from the superclass (bad!)
  3. You inherit methods from the superclass (bad!)
  4. It implies isinstance relationships (bad!)

There is really only one reason to subclass (ordering); all other things that happen due to subclassing are unwanted and lead to unintuitive behavior, like isinstance(BatchTensor1(), BatchTensor0) (did you really expect this?!) or like a super() call in BatchTensor1 going to BatchTensor0 (this is never what you want).

So what it really seems like is you want some unrelated mechanism to handle ordering. Amazingly, this can be done entirely in userland:

class BatchTensor0(BatchTensor):  # from outer vmap
    next_type: Type = Tensor
    wrapped: Tensor
    @classmethod
    def __torch_dispatch__(cls, func, types):
        if any(happens_before(t, cls) for t in types):
            return NotImplemented
        ...

class BatchTensor1(BatchTensor):  # from inner vmap
    next_type: Type = BatchTensor0
    wrapped: Union[Tensor, BatchTensor0]
    @classmethod
    def __torch_dispatch__(cls, func, types):
        if any(happens_before(t, cls) for t in types):
            return NotImplemented
        ...

Essentially, each class participating in functorch scans through all the types and checks whether it comes after every other type participating in the operator call (via the next_type chain). If it does, that means it is the latest level, and it should execute. Otherwise, it should return NotImplemented, and eventually we will dispatch to the correct __torch_dispatch__. This is less efficient than handling it in C++ (because we uselessly dispatch to a bunch of __torch_dispatch__ implementations before we get to the correct one), but maybe we can set up some sort of optimization for this case.
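One plausible userland implementation of happens_before walks the next_type chain: t must be dispatched before cls exactly when cls appears in t's chain, i.e. t is a later (inner) level. This is a sketch under that assumption; `torch_dispatch` stands in for __torch_dispatch__ and all other names are invented.

```python
class Tensor:  # stand-in for torch.Tensor; no next_type: it is the root
    pass

def happens_before(t, cls):
    """True iff t must be dispatched before cls, i.e. cls is reachable
    from t via the next_type chain."""
    cur = getattr(t, "next_type", None)
    while cur is not None:
        if cur is cls:
            return True
        cur = getattr(cur, "next_type", None)
    return False

class BatchTensor0:
    next_type = Tensor
    @classmethod
    def torch_dispatch(cls, func, types):
        if any(happens_before(t, cls) for t in types):
            return NotImplemented  # defer to the later level
        return cls.__name__

class BatchTensor1:
    next_type = BatchTensor0
    @classmethod
    def torch_dispatch(cls, func, types):
        if any(happens_before(t, cls) for t in types):
            return NotImplemented
        return cls.__name__

def dispatch(func, types):
    """Try each type in turn; a type that is not the latest level
    returns NotImplemented, and we move on."""
    for t in types:
        if not hasattr(t, "torch_dispatch"):
            continue
        result = t.torch_dispatch(func, types)
        if result is not NotImplemented:
            return result
    raise TypeError(f"no implementation for {func}")
```

Note there is no inheritance anywhere here: ordering is carried entirely by the next_type links, so none of the unwanted isinstance or super() behavior appears.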

Follow-up 2: You don’t even need tests in torch dispatch

In Python Mode ([Reland] Add python mode by zou3519 · Pull Request #64360 · pytorch/pytorch · GitHub) Richard Zou adds support for a Python Mode that lets you override the meaning of all operators, without needing to have a data dependent flow into the type in question. Although the pull request only handles a single Python mode, the scheme generalizes to handle a stack of Python modes. Python modes participate in __torch_dispatch__ resolution by simply adding some more types to the set of overloaded types that we consider when doing multiple dispatch.

This means that we don’t even need to worry about making sure our classes run in the correct order: when we dispatch to them one-by-one via their modes, the modes handle making sure that we execute them in the right order. By the time we finish with all the modes there are no more functorch tensors around, and you can handle base behavior as you like.
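The mode-stack idea can be sketched as follows (invented names throughout; `torch_dispatch` stands in for __torch_dispatch__): mode classes are simply prepended to the candidate types considered during multiple dispatch, innermost mode first.

```python
_mode_stack = []

def push_mode(cls):
    _mode_stack.append(cls)

def pop_mode():
    return _mode_stack.pop()

def dispatch(func, args):
    """Modes are consulted before any argument-derived types, innermost
    (most recently pushed) first."""
    candidates = list(reversed(_mode_stack)) + [type(a) for a in args]
    for cls in candidates:
        handler = getattr(cls, "torch_dispatch", None)
        if handler is None:
            continue
        result = handler(func, args)
        if result is not NotImplemented:
            return result
    raise TypeError(f"no handler for {func}")

class LoggingMode:
    @classmethod
    def torch_dispatch(cls, func, args):
        return f"logged {func}"
```

While a mode is pushed, it sees every call regardless of the argument types; once popped, dispatch falls back to the argument-derived types as before.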

This means that torch dispatch type testing, while hypothetically it could be very intricate, in practice only needs to be good enough to resolve some simple precedence issues in mixed dispatch cases, and we can assume the dispatch here won’t be that complicated (and therefore avoid falling off the crazy complicated multiple-dispatch OO cliff).