What (and Why) is __torch_dispatch__?

With Alban Desmaison, Edward Yang, and Richard Zou.

You may have seen us mention __torch_dispatch__ in various places recently. For example:

But why have so many cool new things popped up with __torch_dispatch__? And why are we (PyTorch) doing it at all in the first place?

To answer that, let’s bring it back to the beginning.

TL;DR: __torch_dispatch__ allows you to arbitrarily extend PyTorch with all the power of the dispatcher, but now from Python. We hope this opens up a whole new frontier of flexibility for PyTorch.

What is the core of PyTorch? (Spoiler: the dispatcher)

At a high level, there are 2 things that PyTorch does.

  1. Depending on the inputs, figure out an appropriate kernel to run, and whether that’s a CUDA or a CPU implementation.
  2. Depending on the inputs, register the appropriate things in the autograd graph.

These two things are what turn PyTorch from just “NumPy” into “NumPy with CUDA support and autograd”. Crucially, both are implemented in PyTorch through the dispatcher.

At its core, the dispatcher is a system that decides which function to call based on properties of its inputs. To understand more, I suggest reading this excellent post by Edward Yang.

For example, let’s say we call something like aten::sin(Tensor). What actually happens? First, we check whether the Tensor requires grad. If so, we call aten::sin_with_backward (not a real op, but morally this is what builds the backward pass). Then, if the Tensor is on CUDA, we dispatch to aten::sin_with_backward_cuda. Beyond autograd, features like Automatic Mixed Precision and vmap are also implemented with the dispatcher.
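One way to watch this from Python is a small logging sketch built on `TorchDispatchMode` from `torch.utils._python_dispatch` (an underscore-prefixed module, so treat the import path as version-dependent). The mode sits at the Python dispatch key, below autograd, so it sees the forward `aten.sin` call and, once `backward()` runs, the ops that sin's derivative formula executes (for real inputs this includes `aten.cos`). The class name `LoggingMode` is our own.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class LoggingMode(TorchDispatchMode):
    """Records every aten op that reaches the Python dispatch key."""

    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.ops.append(func)          # func is an OpOverload, e.g. aten.sin.default
        return func(*args, **kwargs)   # redispatch to the normal kernel


x = torch.randn(4, requires_grad=True)
with LoggingMode() as log:
    y = torch.sin(x)       # forward: autograd records a node, then the sin kernel runs
    y.sum().backward()     # backward: sin's derivative is computed from cos(x)

for op in log.ops:
    print(op)
```

The exact list of logged ops can vary between PyTorch versions; the point is that both the forward op and the backward-pass ops pass through the same hook.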

In essence, the dispatcher is responsible for the core of the functionality that PyTorch provides. In addition, due to its central role in PyTorch, it allows deep integration into the framework in a manner that no other approach can provide. As a result, the dispatcher is also one of the central locations to extend PyTorch functionality. For example, Functorch’s vmap can transparently work with nearly all of the functionality in PyTorch (including autograd). Why? Because it lives in the dispatcher.

Let’s go into some examples of what you might want to do with a dispatch system.

Why have a dispatcher system?

If you think about it, it’s actually somewhat remarkable what PyTorch does. You can take normal Python code, and then, with no modifications except setting requires_grad on an input, it can do something totally different - compute gradients!

The PyTorch dispatcher is one implementation of the general idea of dynamic dispatch. To see why we need it, let’s look at one naive way device dispatch could be implemented:

def sin(x: Tensor):
    # sin_cuda / sin_cpu stand in for the device-specific kernels
    if x.device.type == 'cuda':
        return sin_cuda(x)
    else:
        return sin_cpu(x)

Other than being quite ugly, this creates a composability problem: we can’t extend sin without modifying the function implementation itself! For example, let’s say we wanted to add vmap. Do we add another conditional inside the function?

So, instead, we let the dispatcher decide which implementation of sin to call based on properties of the input. Now we have something like this:

def sin(x: Tensor[requires_grad=False]): return sin_without_grad(x)
def sin(x: Tensor[requires_grad=True]): return sin_with_grad(x)
def sin(x: Tensor[is_batched=True]): return sin_batched(x)
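To make the idea concrete, here is a toy, pure-Python sketch of such a dispatcher (all names here, such as `ToyTensor` and `register_sin`, are hypothetical and unrelated to PyTorch’s real C++ dispatcher). Implementations register themselves with a priority and a predicate on the input; `sin` itself never changes, so adding a new behavior means registering one more entry. The priorities loosely mirror how dispatch keys are ordered.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a tensor: just values plus two flags."""
    values: List[float]
    requires_grad: bool = False
    is_batched: bool = False


# Registry of (priority, predicate, implementation); higher priority is tried first
_SIN_IMPLS: List[Tuple[int, Callable[[ToyTensor], bool], Callable[[ToyTensor], str]]] = []


def register_sin(priority: int, predicate: Callable[[ToyTensor], bool]):
    def decorator(impl):
        _SIN_IMPLS.append((priority, predicate, impl))
        _SIN_IMPLS.sort(key=lambda entry: -entry[0])
        return impl
    return decorator


def sin(x: ToyTensor) -> str:
    # sin itself never changes; it just picks the first matching implementation
    for _, predicate, impl in _SIN_IMPLS:
        if predicate(x):
            return impl(x)
    raise NotImplementedError("no sin implementation matches")


# Each implementation just reports its name so the chosen path is visible
@register_sin(0, lambda x: True)
def sin_without_grad(x):
    return "sin_without_grad"


@register_sin(10, lambda x: x.requires_grad)
def sin_with_grad(x):
    return "sin_with_grad"


# Added later, without touching sin() or the other implementations
@register_sin(20, lambda x: x.is_batched)
def sin_batched(x):
    return "sin_batched"


print(sin(ToyTensor([0.0])))                                       # sin_without_grad
print(sin(ToyTensor([0.0], requires_grad=True)))                   # sin_with_grad
print(sin(ToyTensor([0.0], requires_grad=True, is_batched=True)))  # sin_batched
```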

Even better, in many situations we can reuse other implementations. For example, sin_with_grad probably still ends up calling sin_without_grad somewhere inside it. Perhaps it looks something like:

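def sin(x: Tensor[requires_grad=True]):
    no_grad_x = x.requires_grad(False)
    # Reuse the plain implementation; its output does not require grad
    out: Tensor[requires_grad=False] = sin(no_grad_x: Tensor[requires_grad=False])
    # The derivative of sin is cos, so the backward multiplies the incoming grad by cos(x)
    out.register_backwards_function(lambda grad: grad * cos(no_grad_x))
    return out.requires_grad(True)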
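In real PyTorch, the closest user-facing analogue of this “grad-aware sin that reuses the plain sin” is a `torch.autograd.Function`. The sketch below is our own illustration, not how `aten::sin` is actually registered internally: `forward` runs with grad mode disabled, so the `torch.sin` call inside it takes the no-grad path, and `backward` supplies the derivative `grad * cos(x)`.

```python
import torch


class Sin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Autograd disables grad mode inside forward, so this is the "no grad" sin
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_out):
        # d/dx sin(x) = cos(x); the chain rule multiplies by the incoming gradient
        (x,) = ctx.saved_tensors
        return grad_out * torch.cos(x)


x = torch.randn(5, requires_grad=True)
y = Sin.apply(x)      # y.requires_grad is True: the output is wired into the graph
y.sum().backward()
print(torch.allclose(x.grad, torch.cos(x.detach())))
```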

As a side note, it can make more sense to think of this special behavior as wrapper subclasses rather than as attributes on the tensor. The examples above might then look like:

def sin(x: Tensor): # Base tensor, just calls sin
def sin(x: GradTensor(Tensor)): # Wrapper gradient tensor that tracks gradients
def sin(x: BatchedTensor(Tensor)): # Wrapper batched tensor that performs vmap

In fact, many other kinds of functionality can be implemented this way! Logging, tracing, FLOP counting, vmap, diagonal tensors, masked tensors, etc. For example (pseudocode only):

Flop Counting

flop_count = 0
def sin(x: FlopTensor(Tensor)):
    global flop_count
    unwrap_x: Tensor = x.elem  # Unwraps FlopTensor to get the underlying Tensor
    flop_count += get_sin_flops(unwrap_x.shape)  # Counts flops
    out = sin(unwrap_x)  # Calls sin on the unwrapped tensor (i.e. redispatches)
    return FlopTensor(out)  # Rewraps so downstream ops keep counting
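A runnable version of this FlopTensor can be sketched as a wrapper subclass with `__torch_dispatch__`. Assumptions: `get_sin_flops` is left unspecified above, so we count one FLOP per element of `aten.sin`; and `torch.Tensor._make_wrapper_subclass` and `torch._C._disabled_torch_function_impl` are underscore-prefixed APIs (used throughout albanD/subclass_zoo) that may change between releases.

```python
import torch
from torch.utils._pytree import tree_map


class FlopTensor(torch.Tensor):
    flop_count = 0  # class-level counter shared by all FlopTensors

    @staticmethod
    def __new__(cls, elem):
        # Wrapper subclass: the outer tensor only carries metadata; data lives in .elem
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
            device=elem.device, requires_grad=elem.requires_grad)
        r.elem = elem
        return r

    # Skip __torch_function__ so every op goes straight to __torch_dispatch__
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(x):
            return x.elem if isinstance(x, FlopTensor) else x

        def wrap(x):
            return FlopTensor(x) if isinstance(x, torch.Tensor) else x

        uargs, ukwargs = tree_map(unwrap, args), tree_map(unwrap, kwargs)
        if func is torch.ops.aten.sin.default:
            cls.flop_count += uargs[0].numel()  # assumed cost: 1 FLOP per element
        out = func(*uargs, **ukwargs)           # redispatch on plain tensors
        return tree_map(wrap, out)              # rewrap so downstream ops keep counting


t = FlopTensor(torch.randn(3, 4))
out = torch.sin(t)
print(type(out).__name__, FlopTensor.flop_count)  # FlopTensor 12
```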

Tracer

class ProxyTensor(Tensor):
    elem: Tensor
    proxy: Proxy

def sin(x: ProxyTensor(Tensor)):
    proxy = x.proxy
    unwrap_x = x.elem
    out = sin(unwrap_x)  # Runs the real computation
    proxy_out = proxy.call_function('sin')  # Records the call in the traced graph
    return ProxyTensor(out, proxy_out)
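PyTorch ships a real tracer built on the same `__torch_dispatch__` machinery: `make_fx` in `torch.fx.experimental.proxy_tensor` runs a function once on example inputs and records each aten op it dispatches into an FX graph. The module path says “experimental”, so treat it as version-dependent. A minimal sketch (the function `f` is our own example):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    return torch.sin(x) * 2


gm = make_fx(f)(torch.randn(3))  # runs f on a real input, recording aten ops
print(gm.graph)

# The recorded call targets are aten OpOverloads, e.g. aten.sin.default
targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
print(targets)
```

The resulting `GraphModule` can be called like the original function, which is what makes it useful for handing off to a compiler.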

Basically, the dispatcher lets you do all sorts of cool stuff and override all sorts of PyTorch behavior in a composable manner. But… it came with a lot of restrictions. First, it’s difficult to register new functionality with the dispatcher without talking to the PyTorch core team. Second, this functionality had to be registered in C++!

So, as a high level goal, __torch_dispatch__ allows you to leverage all the power of the dispatcher, but from Python!

Why is __torch_dispatch__ important?

Let’s take a look at a typical flow for calling an operator in PyTorch, as well as the various points where we can modify behavior.

This is a diagram of how __torch_dispatch__ might work with vmap. The black arrows represent the paths taken; the dotted arrows represent paths that could have been taken, depending on the dispatch keys.

Note that (1) __torch_dispatch__ sits after the vmap behavior (and can thus capture it), and (2) __torch_dispatch__ is the only way to go from C++ back into Python.

As an oversimplification, nearly every extension point in PyTorch today (with a couple of exceptions) happens even before step 1. This means that, to some extent, none of this functionality can know what’s happening further down the stack! That rules out a lot of potential functionality.

As a simple example, take FLOP counting. The first FLOP counters in PyTorch were all implemented above the framework, at the module level. This kinda worked, but broke down as soon as users used non-standard modules or did things inside their modules. Next, folks used approaches that hooked in within the PyTorch framework but above C++ (i.e. __torch_function__ and FX), which allowed them to capture operators within modules. But these approaches (1) never worked for capturing the backward pass, and (2) would definitely never work for capturing things like Jacobian or Hessian FLOP counts.

It is only by integrating into the C++ dispatcher with __torch_dispatch__ that we created a FLOP counter that could capture backward FLOPs.
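To make the backward-pass point concrete, here is a minimal sketch of a matmul-only FLOP counter written as a `TorchDispatchMode` (from the underscore-prefixed `torch.utils._python_dispatch`). The class name `MatmulFlopCounter` and the 2·m·k·n cost model for an (m×k)@(k×n) product are our own choices, not an official API. Because the mode lives at the Python dispatch key, below autograd, it also sees the `aten.mm` calls autograd issues during `backward()`.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch.ops.aten


class MatmulFlopCounter(TorchDispatchMode):
    """Counts FLOPs for aten.mm only, using 2*m*k*n per (m x k) @ (k x n) product."""

    def __init__(self):
        super().__init__()
        self.flops = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)
        if func is aten.mm.default:
            a, b = args[0], args[1]
            m, k = a.shape
            n = b.shape[1]
            self.flops += 2 * m * k * n
        return out


a = torch.randn(4, 5, requires_grad=True)
b = torch.randn(5, 3, requires_grad=True)

with MatmulFlopCounter() as counter:
    out = torch.mm(a, b)
    forward_flops = counter.flops
    out.sum().backward()  # autograd computes grad_a and grad_b with two more mm calls

print(forward_flops, counter.flops)
```

The forward product costs 2·4·5·3 = 120 FLOPs; each of the two gradient matmuls in the backward pass has the same cost, so the total should come to 360.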

Basically, __torch_function__ only lets you modify what happens in Python. If you want control over everything that’s happening in PyTorch, you need __torch_dispatch__.

So, what does it look like?

Let’s look at a simple example: say you want to replace every call to aten::add with aten::sub.

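import torch
from torch.utils._pytree import tree_map

class FooTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # A "wrapper" subclass: the outer tensor holds only metadata; the real data lives in .elem
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
            device=elem.device, requires_grad=elem.requires_grad)
        r.elem = elem
        return r

    # Send every op straight to __torch_dispatch__
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # First, we must unwrap the wrapper tensors to get the inner tensor objects
        def unwrap(x):
            return x.elem if isinstance(x, FooTensor) else x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
        # Now, we check the function to determine how to handle it. func is a
        # specific overload (e.g. aten.add.Tensor), so we compare its overload packet.
        # If it's aten.add, we call aten.sub. Otherwise, we pass through to
        # the original function.
        if func.overloadpacket == torch.ops.aten.add:
            out = torch.ops.aten.sub(*args, **kwargs)
        else:
            out = func(*args, **kwargs)

        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
            return FooTensor(x) if isinstance(x, torch.Tensor) else x

        return tree_map(wrap, out)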
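The same add-to-sub rewrite can also be sketched as a dispatch mode instead of a subclass. A `TorchDispatchMode` (from the underscore-prefixed `torch.utils._python_dispatch`, so the import path may change between releases) intercepts every aten op issued inside a `with` block, without wrapping any tensors. The class name `AddToSubMode` is our own.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class AddToSubMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # func is an OpOverload (e.g. aten.add.Tensor); compare its packet to catch every add overload
        if func.overloadpacket == torch.ops.aten.add:
            # Calling the packet lets PyTorch pick the matching sub overload
            return torch.ops.aten.sub(*args, **kwargs)
        return func(*args, **kwargs)


a = torch.tensor([5.0, 7.0])
b = torch.tensor([1.0, 2.0])
with AddToSubMode():
    c = a + b
print(c)  # tensor([4., 5.])
```

A subclass only affects tensors you explicitly wrap, whereas a mode affects everything run inside its scope; which one fits depends on whether the behavior belongs to the data or to the region of code.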

As you can see, __torch_dispatch__ gives an immense amount of flexibility. For each aten operator, we are able to do whatever we want to it, including:

  1. Do things before the operator (including logging or actual modification of values)
  2. Do things after the operator (same as above).
  3. Call an entirely arbitrary implementation of a function (e.g. calling Numpy or another compiler).
  4. Redispatch back to the default implementation.
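Several of these capabilities fit in one short hook. The sketch below (CPU only, assuming NumPy is installed; `NumpySinMode` is a hypothetical name) logs each op before running it, computes `aten.sin` with NumPy instead of PyTorch’s kernel, and redispatches every other op to the default implementation.

```python
import numpy as np
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class NumpySinMode(TorchDispatchMode):
    """Routes aten.sin to NumPy (CPU only); every other op is redispatched unchanged."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.seen.append(func)  # act before the operator (here: just log it)
        if func is torch.ops.aten.sin.default:
            (x,) = args
            # An entirely different implementation: compute sin with NumPy
            return torch.from_numpy(np.sin(x.detach().cpu().numpy()))
        return func(*args, **kwargs)  # redispatch to the default kernel


x = torch.randn(5)
with NumpySinMode() as mode:
    y = torch.sin(x)
print(torch.allclose(y, torch.sin(x)), torch.ops.aten.sin.default in mode.seen)
```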

Note that item 3, the ability to call an entirely arbitrary implementation of a function, also gives us an immense amount of flexibility in how a tensor is actually represented. For example, we could store the tensor as an Int8 quantized tensor, dequantize it, and then call the original function:

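@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    # Replace each QuantTensor argument with a regular, dequantized tensor
    def unwrap(e):
        if isinstance(e, QuantTensor):
            return cls.dequantize(e.mat, e.row_factor, e.column_factor, e.requires_grad, e.dtype)
        else:
            return e
    # Run the original aten op on plain tensors
    out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
    return out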
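The snippet above leaves `QuantTensor` and its `dequantize` helper undefined. As a self-contained stand-in, here is a hypothetical `Int8Tensor` using simpler per-tensor symmetric quantization (one scale instead of the row/column factors above). It stores only int8 data plus a scale, and dequantizes on the fly inside `__torch_dispatch__`. `_make_wrapper_subclass` and `_disabled_torch_function_impl` are underscore-prefixed APIs and may change between releases.

```python
import torch
from torch.utils._pytree import tree_map


class Int8Tensor(torch.Tensor):
    """Hypothetical toy: per-tensor symmetric int8 quantization."""

    @staticmethod
    def __new__(cls, qdata, scale, dtype):
        # The wrapper reports the original float dtype; the payload is int8
        r = torch.Tensor._make_wrapper_subclass(cls, qdata.size(), dtype=dtype, device=qdata.device)
        r.qdata = qdata
        r.scale = scale
        return r

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def quantize(cls, x):
        scale = x.abs().max() / 127
        qdata = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
        return cls(qdata, scale, x.dtype)

    def dequant(self):
        return self.qdata.to(self.dtype) * self.scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(e):
            return e.dequant() if isinstance(e, Int8Tensor) else e

        # Dequantize every Int8Tensor argument, then run the original op
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))


x = torch.randn(4, 4)
q = Int8Tensor.quantize(x)
y = torch.mm(q, torch.eye(4))  # a plain float tensor, close to x
print(q.qdata.dtype, (y - x).abs().max().item() <= q.scale.item())
```

Rounding error per element is at most half the scale, so the result stays within one scale step of the original values.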

I’ve skipped many details for the sake of a clearer example. So, if you want to actually try out __torch_dispatch__ or learn more about the details, check out https://github.com/albanD/subclass_zoo.

What is the long term vision?

Let’s take a ResNet18 model written in PyTorch. What is that program? One way to view it is that it’s simply a high level representation of assembly code. But… it doesn’t just map to a single sequence of instructions. Depending on the inputs, it could run on CPU, GPU, or TPU. Depending on whether we require grad, it might save activations for the backwards pass or it might not. Depending on whether we activate Automatic Mixed Precision, it might automatically cast tensors between float32 and float16 or it might not.

Perhaps a better way to view it is as an abstract representation of our model. Much as a mathematical formula can be translated into code, PyTorch translates this abstract representation of our model into a gazillion different actually executed pieces of code. But we could do much more than the examples above.

While keeping the modeling code identical, using __torch_dispatch__, users should be able to

And moreover, these should compose. Users should be able to compute per-sample gradients of MaskedTensors whose mask is represented as a diagonal matrix, parallelize this using tensor/model/data parallelism, and then trace out the whole thing so we can pass it to a compiler.
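Some of this composition already exists in core PyTorch: `torch.func.vmap` and `torch.func.grad` (the functorch transforms, exposed as `torch.func` in PyTorch 2.0 and later) are both dispatcher-based and nest freely. A minimal per-sample-gradient sketch for a hypothetical single-example squared-error loss:

```python
import torch
from torch.func import grad, vmap


def loss(w, x, y):
    # Squared error of a single example: x and w are vectors, y is a scalar
    return (x @ w - y) ** 2


w = torch.randn(3)
xs = torch.randn(8, 3)
ys = torch.randn(8)

# grad(loss) differentiates w.r.t. w (argnums=0 by default);
# vmap maps it over the batch dimension of xs and ys but not w
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(w, xs, ys)
print(per_sample_grads.shape)  # torch.Size([8, 3])

# Closed form for comparison: d/dw (x.w - y)^2 = 2 (x.w - y) x
expected = 2 * (xs @ w - ys).unsqueeze(1) * xs
print(torch.allclose(per_sample_grads, expected, atol=1e-5))
```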

None of these things were impossible to do before, they just required significant investment from PyTorch core. __torch_dispatch__ just opens that up to the people.

PS: I’ll note that this last point, about tracing, is likely to be fairly important. All of these tensor subclasses are just abstractions over the underlying kernel operations we want to execute. Tracing lets us see through all of these layers of abstraction and get at the underlying tensor operations.

So, as an example, we could (hypothetically) add a custom “4-bit butterfly sparsity tensor”, and as long as all of its underlying operations are ATen operators, you could train/evaluate with this tensor (all in Python!), then trace out the tensor semantics and export to mobile! Note that this isn’t even a hypothetical: this works today.

Other Resources

Subclass zoo: GitHub - albanD/subclass_zoo


Hi Horace. Great post!
The link you posted at the end:


While reading the source code of make_fx, I encountered three functions:

no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch

I guess they might be a set of important functions related to the dispatcher. After searching the codebase, I still can’t summarize how they work exactly. Could you explain what these functions do? @Chillee

Fantastic tutorial and explanation on torch_dispatch