With Alban Desmaison, Edward Yang, and Richard Zou.

You may have seen us mention `__torch_dispatch__` in various places recently. For example:

- Making a Flop Counter that counts backwards FLOPS
- Building a tracing frontend to make it easy to use compilers in the backwards pass (see "Min-cut optimal(*) recomputation (i.e. activation checkpointing) with AOTAutograd")

But why have so many cool new things popped up with `__torch_dispatch__`? And why are we (PyTorch) doing it at all in the first place?

To answer that, let's bring it back to the beginning.

**TL;DR: `__torch_dispatch__` allows you to arbitrarily extend PyTorch with all the power of the dispatcher, but now from Python. This will hopefully open up a whole new frontier of flexibility for PyTorch, all in Python.**

## What is the core of PyTorch? (Spoiler: the dispatcher)

At a high level, there are 2 things that PyTorch does.

- Depending on the inputs, figure out an appropriate kernel to run, and whether that's a CUDA or a CPU implementation.
- Depending on the inputs, register the appropriate things in the autograd graph.

These things are what turn PyTorch from just "NumPy" into "NumPy with CUDA support and autograd". Crucially, these 2 things are both facilitated in PyTorch through the **dispatcher**.

At its core, the dispatcher is a system that, depending on properties of its input, decides which function it should call. To understand more, I suggest reading this excellent post by Edward Yang.

For example, let's say we have something like `aten::sin(Tensor)`. What actually happens? Well, first, we check whether `Tensor` requires grad. If so, we call `aten::sin_with_backward` (not a real op, but morally this builds the backwards pass). Then, if `Tensor` is on CUDA, we dispatch to `aten::sin_with_backward_cuda`.

Other than autograd, things like Automatic Mixed Precision or vmap are also implemented with the dispatcher.

In essence, **the dispatcher is responsible for the core of the functionality that PyTorch provides.** In addition, due to its central role in PyTorch, it allows deep integration into the framework in a manner that no other approach can provide. As a result, **the dispatcher is also one of the central locations to extend PyTorch functionality.** For example, Functorch's vmap can transparently work with nearly all of the functionality in PyTorch (including autograd). Why? Because it lives in the dispatcher.

Let's go into some examples of what you might want to do with a dispatch system.

## Why have a dispatcher system?

If you think about it, it's actually somewhat remarkable what PyTorch does. You can take normal Python code, and then, with no modifications except setting `requires_grad` on an input, it can do something totally different - compute gradients!

The PyTorch dispatcher is one implementation of a more general idea: dynamic dispatch. To see why it's useful, let's take one way device dispatch could be implemented without it:

```
def sin(x: Tensor):
    # Hard-coded device check: every new backend means editing this function
    if x.device.type == 'cuda':
        return sin_cuda(x)
    else:
        return sin_cpu(x)
```

Other than being quite ugly, this opens up another composability issue - we can't extend `sin` without modifying the actual function implementation itself! For example, let's say we were to add vmap. Do we add another conditional inside of the function?

So, instead, we allow the dispatcher to decide which implementation of `sin` to dispatch to based off of properties of the input. Now we have something like this.

```
def sin(x: Tensor[requires_grad=False]): return sin_without_grad(x)
def sin(x: Tensor[requires_grad=True]): return sin_with_grad(x)
def sin(x: Tensor[is_batched=True]): return sin_batched(x)
```
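The overload-style pseudocode above is not valid Python, but the idea can be sketched in plain Python with a lookup table. This is only a toy illustration, not how PyTorch's C++ dispatcher is implemented; `ToyTensor`, `register_sin`, and the `"fake_accelerator"` device are hypothetical names made up for the example.

```python
import math
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a tensor: a list of floats plus a device tag."""
    data: List[float]
    device: str = "cpu"


# Dispatch table for `sin`, keyed on the property we dispatch on (the device).
SIN_KERNELS: Dict[str, Callable[[ToyTensor], ToyTensor]] = {}


def register_sin(device: str):
    """Decorator that adds a kernel to the table without touching `sin`."""
    def decorator(kernel):
        SIN_KERNELS[device] = kernel
        return kernel
    return decorator


def sin(x: ToyTensor) -> ToyTensor:
    # The entry point only looks up a kernel; it never needs to be edited.
    return SIN_KERNELS[x.device](x)


@register_sin("cpu")
def sin_cpu(x: ToyTensor) -> ToyTensor:
    return ToyTensor([math.sin(v) for v in x.data], "cpu")


@register_sin("fake_accelerator")  # could be registered later, e.g. by an extension
def sin_fake_accelerator(x: ToyTensor) -> ToyTensor:
    return ToyTensor([math.sin(v) for v in x.data], "fake_accelerator")


out = sin(ToyTensor([0.0], device="fake_accelerator"))
```

The point is that `sin` itself never changes: supporting a new property value only requires registering another kernel.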

Even better, in many situations, we can actually reuse other implementations. For example, `sin_with_grad` probably still ends up calling `sin_without_grad` somewhere in it. Perhaps it looks something like this:

```
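def sin(x: Tensor[requires_grad=True]):
    # Pseudocode: drop the grad-tracking property so the call below redispatches
    no_grad_x = x.requires_grad(False)
    out: Tensor[requires_grad=True] = sin(no_grad_x: Tensor[requires_grad=False])
    # Record what the backwards pass needs to do for sin
    out.register_backwards_function(sin)
    return out.requires_grad(True)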
```

As a side note, it could make more sense to think of this special behavior as wrapper subclasses instead of attributes on the tensor. So, the above examples might look like

```
def sin(x: Tensor): # Base tensor, just calls sin
def sin(x: GradTensor(Tensor)): # Wrapper gradient tensor that tracks gradients
def sin(x: BatchedTensor(Tensor)): # Wrapper batched tensor that performs vmap
```
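The wrapper-subclass view maps fairly directly onto ordinary single dispatch on the argument's type. Here is a minimal sketch using `functools.singledispatch` on plain Python floats; `GradTensor` here is a hypothetical toy class with a made-up `backward_fn` attribute, not PyTorch's autograd.

```python
import math
from functools import singledispatch


class GradTensor:
    """Hypothetical wrapper: a value plus a function that propagates gradients."""

    def __init__(self, elem, backward_fn=None):
        self.elem = elem
        self.backward_fn = backward_fn


@singledispatch
def sin(x):
    # Base case: a plain Python float, so just compute sin
    return math.sin(x)


@sin.register(GradTensor)
def sin_grad(x):
    out = sin(x.elem)  # redispatch on the unwrapped value
    # d/dx sin(x) = cos(x), so the backward function scales the incoming gradient
    return GradTensor(out, backward_fn=lambda grad: grad * math.cos(x.elem))


y = sin(GradTensor(0.0))
grad_x = y.backward_fn(1.0)
```

The implementation for the wrapper does its own bookkeeping and then calls `sin` again on what it wraps, which is the "redispatch" step in the pseudocode.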

In fact, many other kinds of functionality can be implemented in this manner! Logging, tracing, flop counting, vmap, diagonal tensor, masked tensor, etc! For example (just pseudocode)

Flop Counting

```
flop_count = 0
def sin(x: FlopTensor(Tensor)):
    global flop_count
    unwrap_x: Tensor = x.elem  # Unwraps FlopTensor to get the underlying Tensor
    flop_count += get_sin_flops(x.shape)  # Counts flops
    out = sin(unwrap_x)  # Calls sin on the unwrapped tensor (i.e. redispatches)
    return FlopTensor(out)
```

Tracer

```
class ProxyTensor(Tensor):
    elem: Tensor
    proxy: Proxy

def sin(x: ProxyTensor(Tensor)):
    proxy = x.proxy
    unwrap_x = x.elem
    out = sin(unwrap_x)  # Redispatches on the unwrapped tensor
    proxy_out = proxy.call_function('sin')  # Records the call in the trace
    return ProxyTensor(out, proxy_out)
```
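To make the composability point concrete, here is a toy, pure-Python version of the two wrappers above that can be nested. It is a sketch under simplifying assumptions: "tensors" are plain lists of floats, the cost model is one FLOP per element, and `stats`, `numel`, `FlopTensor`, and `TraceTensor` are hypothetical names for this example only.

```python
import math
from functools import singledispatch

# Shared bookkeeping for the example
stats = {"flops": 0, "trace": []}


class FlopTensor:
    def __init__(self, elem):
        self.elem = elem


class TraceTensor:
    def __init__(self, elem):
        self.elem = elem


def numel(x):
    """Element count, looking through any number of wrappers."""
    return numel(x.elem) if hasattr(x, "elem") else len(x)


@singledispatch
def sin(x):
    # Base case: a plain list of floats
    return [math.sin(v) for v in x]


@sin.register(FlopTensor)
def sin_flop(x):
    stats["flops"] += numel(x)  # toy cost model: one FLOP per element
    return FlopTensor(sin(x.elem))  # redispatch on whatever is wrapped


@sin.register(TraceTensor)
def sin_trace(x):
    stats["trace"].append("sin")  # record the op name
    return TraceTensor(sin(x.elem))  # redispatch on whatever is wrapped


# Nest the wrappers: the call is counted first, then traced, then computed.
out = sin(FlopTensor(TraceTensor([0.0, 0.0, 0.0])))
```

Neither wrapper knows about the other, yet a single call to `sin` updates both the FLOP count and the trace.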

Basically, the dispatcher allows you to do all sorts of cool stuff, and override all sorts of PyTorch behavior in a composable manner. But... it came along with a lot of restrictions. For one, it's... difficult to register new functionality to the dispatcher without talking to the PyTorch core team. For another, registering this functionality needed to be done in C++!

So, as a high level goal, `__torch_dispatch__` **allows you to leverage all the power of the dispatcher, but from Python!**

## Why is `__torch_dispatch__` important?

Let's take a look at a typical flow for calling an operator in PyTorch, as well as the various points where we can modify behavior.

This is a diagram of how `__torch_dispatch__` might work with vmap. The black arrows represent paths taken, the dotted arrows represent paths that *could* have been taken, depending on the dispatch keys.

Note how 1. `__torch_dispatch__` sits after vmap behavior (and thus can capture it), and 2. `__torch_dispatch__` is the only way to go from C++ back into Python.

As an oversimplification, nearly every extension point (with a couple exceptions) in PyTorch today sits **even before step 1.** What that means is that to some extent, none of this functionality can have any idea what's happening beneath the surface! This inhibits a lot of potential functionality.

As a simple example, take FLOP counting. The first FLOP counters in PyTorch were all implemented above the framework - they would be implemented at the module level. This kinda worked, but broke down as soon as users used non-standard modules or did things *within* their modules. Next, folks used approaches implemented within the PyTorch framework, but above C++ (i.e. `__torch_function__` and FX), which allowed them to capture operators *within* modules. But... these approaches 1. never worked for capturing the backwards pass, and 2. would definitely never work for capturing things like Jacobian or Hessian FLOP counts.

It is only by integrating inside of the C++ dispatcher with `__torch_dispatch__` that we created a FLOP counter that could capture backwards FLOPS.
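As a small illustration of the claim above, the sketch below logs every ATen operator that runs during a forward and a backward pass. It is not the FLOP counter referenced in this post. It assumes a recent PyTorch release and uses `TorchDispatchMode` from the private module `torch.utils._python_dispatch`, a mode-based way of installing a `__torch_dispatch__` handler that this post does not otherwise cover; `OpLogger` is a hypothetical name.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class OpLogger(TorchDispatchMode):
    """Records the name of every ATen op that reaches __torch_dispatch__."""

    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))  # e.g. "aten.sin.default"
        return func(*args, **(kwargs or {}))  # run the real implementation


x = torch.randn(3, requires_grad=True)
with OpLogger() as logger:
    loss = x.sin().sum()
    loss.backward()  # ops executed by autograd are logged as well

print(logger.ops)
```

The derivative of `sin` is `cos`, so finding a `cos` operator in the log, even though the code never calls it directly, shows the backward pass being captured.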

**Basically, `__torch_function__` merely allows you to modify what happens in Python, but if you want control over everything that's happening in PyTorch? You need to use `__torch_dispatch__`.**

## So, what does it look like?

Let's take a look at a simple example. Say you wanted to replace every instance of `aten::add` with `aten::sub`.

```
import torch
from torch.utils._pytree import tree_map

class FooTensor(torch.Tensor):
    # The constructor (which stores the wrapped tensor as `self.elem`) is omitted here
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # First, we must unwrap the wrapper tensors to get the inner tensor object
        def unwrap(x):
            return x.elem if isinstance(x, FooTensor) else x
        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs or {})
        # Now, we check the function to determine how to handle it. If it's
        # aten.add, then we call aten.sub. Otherwise, we pass through to
        # the original function. Note: recent PyTorch releases pass a specific
        # overload here (e.g. aten.add.Tensor) rather than the overload packet
        if func is torch.ops.aten.add.Tensor:
            out = torch.ops.aten.sub.Tensor(*args, **kwargs)
        else:
            out = func(*args, **kwargs)
        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
            return FooTensor(x) if isinstance(x, torch.Tensor) else x
        return tree_map(wrap, out)
```

As you can see, `__torch_dispatch__` gives an immense amount of flexibility. For each aten operator, we are able to do whatever we want to it, including:

- Do things before the operator (including logging or actual modification of values)
- Do things after the operator (same as above).
- Call an entirely arbitrary implementation of a function (e.g. calling Numpy or another compiler).
- Redispatch back to the default implementation.

Note that the third point, the ability to call an entirely arbitrary implementation of a function, also gives us an immense amount of flexibility when it comes to actual tensor representation. For example, we could represent the tensor as an Int8 quantized tensor, dequantize the tensor, and then call the original function.

```
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    def unwrap(e):
        # Dequantize wrapped tensors so the original op sees a regular tensor
        if isinstance(e, QuantTensor):
            return cls.dequantize(e.mat, e.row_factor, e.column_factor, e.requires_grad, e.dtype)
        else:
            return e
    out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
    return out
```

I've skipped many details for the sake of a clearer example. So, if you want to actually try out `__torch_dispatch__` or learn more about the details, check out https://github.com/albanD/subclass_zoo.
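One of the skipped details is how the wrapper gets constructed. The following is a minimal runnable sketch of the `FooTensor` example, assuming a recent PyTorch release where `func` arrives as a specific overload such as `aten.add.Tensor`. It relies on `torch.Tensor._make_wrapper_subclass`, a private helper whose signature may change between releases, so treat this as one possible way to fill in the gaps rather than the canonical one.

```python
import torch
from torch.utils._pytree import tree_map


class FooTensor(torch.Tensor):
    """Wrapper subclass that turns every aten.add.Tensor into aten.sub.Tensor."""

    @staticmethod
    def __new__(cls, elem):
        # Create a data-less wrapper whose metadata mirrors the wrapped tensor
        # (private API, assumed available in the installed PyTorch version).
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device
        )
        r.elem = elem  # the plain tensor that actually holds the data
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(x):
            return x.elem if isinstance(x, FooTensor) else x

        def wrap(x):
            return FooTensor(x) if isinstance(x, torch.Tensor) else x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs or {})
        if func is torch.ops.aten.add.Tensor:
            func = torch.ops.aten.sub.Tensor  # swap the operator
        return tree_map(wrap, func(*args, **kwargs))


a = FooTensor(torch.tensor([5.0, 7.0]))
b = torch.tensor([1.0, 2.0])
added = (a + b).elem       # computed as a - b because of the override
multiplied = (a * b).elem  # other operators pass through unchanged
```

Reading `.elem` pulls out the plain tensor, which avoids printing the wrapper itself (printing would trigger further dispatches on the subclass).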

## What is the long term vision?

Let's take a ResNet18 model written in PyTorch. What is that program? One way to view it is that it's simply a high level representation of assembly code. But... it doesn't just map to a single sequence of instructions. Depending on the inputs, it could run on CPU, GPU, or TPU. Depending on whether we require grad, it might save activations for the backwards pass or it might not. Depending on whether we activate Automatic Mixed Precision, it might automatically cast tensors between float32 and float16 or it might not.

Perhaps a better way to view it is as an abstract representation of our model. Much as a mathematical formula can be translated into code, PyTorch translates this abstract representation of our model into a gazillion different actually executed pieces of code. But we could do much more than the examples above.

While keeping the modeling code identical, using `__torch_dispatch__`, users should be able to:

- Compute efficient per-sample gradients
- Train 10 copies of ResNet18 efficiently at the same time
- Compute the FLOPS
- Parallelize it across arbitrary number of devices using arbitrary horizontal/vertical parallelism
- Represent the weights in some arbitrary more efficient representation
  - Low rank approximation
  - Butterfly sparsity
  - Factorized compression
  - Taichi sparsity data structure
  - 8-bit quantized format
  - Diagonal Tensors
  - Linear Operator Tensors
  - Arbitrary Einsum Tensors

- Pass MaskedTensors in as inputs
- Keep the tensors on SSD until they're needed, and then load them from SSD when they're used
- Execute in some arbitrary lazy execution manner (like LazyTensor)
- Trace out the operations occurring in the graph (i.e. AOTAutograd)
- A billion more things we haven't thought of yet.

And moreover, these should **compose**. Users should be able to compute per-sample gradients of MaskedTensors whose mask is represented as a diagonal matrix, parallelize this using tensor/model/data parallelism, and then trace out the whole thing so we can pass it to a compiler.

None of these things were *impossible* to do before, they just required significant investment from PyTorch core. `__torch_dispatch__` just opens that up to the people.

PS: I'll note that this last point, about tracing, is likely to be fairly important. All of these tensor subclasses we want to do are just abstractions over the underlying kernel operations we want to execute. But, tracing allows us to trace through *all* of these layers of abstraction, and get at the underlying tensor operations.

So, as an example, we could (hypothetically) add a custom "4-bit butterfly sparsity tensor", and as long as all of its underlying operations are ATen operators, users could train/evaluate with this tensor (all in Python!), and then trace out the tensor semantics and export to mobile! Note that this isn't even a hypothetical, this works today.

## Other Resources

Subclass zoo: https://github.com/albanD/subclass_zoo