TorchDispatchMode for Debugging, Testing, and More

Intro

Over the past few months since we released and improved TorchDispatchMode, we’ve seen many users improve their workflows with small, simple TorchDispatchModes. This note summarizes the background and basic usage of TorchDispatchMode, and goes over some of the differences between modes and subclasses.

What is Torch Dispatch Mode?

Torch Dispatch Mode gives users a way to interpose at the __torch_dispatch__ level on all calls, including factory functions. You may have seen other notes (or entire repos) about using Tensor subclasses for similar purposes. Modes are similar to subclasses, but they also let us capture factory functions, i.e. functions that don’t take tensor arguments, like torch.randn or torch.ones.
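As a minimal sketch of that difference, the mode below records every op it intercepts, including torch.ones and torch.randn, which receive no tensor arguments and so would never reach a subclass. FactoryRecorder is a hypothetical name; TorchDispatchMode is imported from torch.utils._python_dispatch, a private module whose location may change between releases.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class FactoryRecorder(TorchDispatchMode):
    """Hypothetical mode that records the name of every ATen op it intercepts."""

    def __init__(self):
        super().__init__()
        self.seen = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # func is an ATen OpOverload (e.g. the overload behind torch.ones)
        self.seen.append(str(func))
        return func(*args, **kwargs)


recorder = FactoryRecorder()
with recorder:
    # Neither call receives a tensor argument, so a subclass would never see them.
    x = torch.ones(3)
    y = torch.randn(2)

print(recorder.seen)
```

The recorded names are ATen-level overloads rather than the Python functions that were called, because __torch_dispatch__ runs below the Python API.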

Basics

Here’s some code for a basic PrintingMode, which just prints out every call that it sees:

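from torch.utils._python_dispatch import TorchDispatchMode

class PrintingMode(TorchDispatchMode):
  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}  # kwargs may be None
    print(f"{func.__module__}.{func.__name__}({args}, {kwargs})")
    return func(*args, **kwargs)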

Stateful

Another benefit of modes over subclasses is that a single mode object can carry state that applies to every call it intercepts. For example, if you wanted the same PrintingMode but with everything written to a logger object instead of printed, you could do:

class LoggingMode(TorchDispatchMode):
  def __init__(self, logger):
    super().__init__()
    self.logger = logger

  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    # logger is any object with a log(name, args, kwargs) method
    self.logger.log(f"{func.__module__}.{func.__name__}", args, kwargs)
    return func(*args, **kwargs)

Note that a given mode always writes to the same logger object, and you could have multiple LoggingModes running at the same time with different logger objects.

Also note that this requires an implemented logger object. If you need an already working version, PyTorch has a fully working LoggingTensorMode (in torch.testing._internal.logging_tensor).
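For a concrete picture, here is a sketch with a toy logger. ListLogger is hypothetical, not a PyTorch API: it implements only the log(name, args, kwargs) method that LoggingMode calls and keeps entries in a list. Nesting two LoggingModes with different loggers shows that each mode keeps its own state.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class ListLogger:
    """Hypothetical logger that stores (op name, args, kwargs) tuples in memory."""

    def __init__(self):
        self.entries = []

    def log(self, name, args, kwargs):
        self.entries.append((name, args, kwargs))


class LoggingMode(TorchDispatchMode):
    def __init__(self, logger):
        super().__init__()
        self.logger = logger

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.logger.log(f"{func.__module__}.{func.__name__}", args, kwargs)
        return func(*args, **kwargs)


outer_log, inner_log = ListLogger(), ListLogger()
with LoggingMode(outer_log):
    with LoggingMode(inner_log):
        # The inner mode sees each op first; its func(...) call then reaches the outer mode.
        z = torch.ones(2) + 1

print([name for name, _, _ in inner_log.entries])
print([name for name, _, _ in outer_log.entries])
```

Because the inner mode forwards every op it handles, both loggers end up with the same sequence of calls.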

Usage

To use these modes, wrap the calls that you want to log in a with block:

with LoggingMode(logger):
  out = torch.ones(2) + 1  # every op inside the block is logged

You may have seen others use something like:

with enable_torch_dispatch_mode(LoggingMode(logger)):
  out = torch.ones(2) + 1

In most cases these do the same thing. However, if you are adding a mode to code that is already running under a mode, you’ll need to use with LoggingMode(...); the enable_torch_dispatch_mode version raises an error telling you so.

Example Usages

We’ve already seen modes help debug these and other errors:

Additionally, it’s being used in testing of:

When should I use a Tensor Subclass? When should I use a mode? When should I use both?

I’ll start by noting that there’s some subtlety here and that often either would work (before modes existed, we used subclasses to solve similar problems). Personally, I think modes are faster to write and have fewer gotchas, so I would recommend starting with a mode and adding a subclass only if necessary.

With that disclaimer, the basics are this: use a mode if you just want to see every function that hits __torch_dispatch__. Use a subclass if you need to follow specific tensors across different calls or propagate tensor-specific state, like dispatch keys. Let’s break this down a little more:

Mode

Consider the debugging modes we just saw. For these, we only care about the function being called and the arguments being passed to it; we don’t need to track individual tensors from one call to the next. We also intend to run every function on the tensors that were passed in. So a mode is sufficient here.
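A sketch of that kind of mode, assuming all we want is a tally of which ops run: the hypothetical OpCounterMode below never inspects or wraps tensors. It only keys a collections.Counter on the op and then runs the op unchanged on the original arguments.

```python
from collections import Counter

import torch
from torch.utils._python_dispatch import TorchDispatchMode


class OpCounterMode(TorchDispatchMode):
    """Hypothetical mode that counts how many times each ATen op is called."""

    def __init__(self):
        super().__init__()
        self.counts = Counter()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.counts[str(func)] += 1
        # Run the real op on the tensors exactly as they were passed in.
        return func(*args, **kwargs)


a = torch.randn(4, 4)
b = torch.randn(4, 4)
counter = OpCounterMode()
with counter:
    c = a @ b
    d = c.relu() + c

print(counter.counts.most_common())
```

No per-tensor bookkeeping is needed, which is exactly why a mode is enough for this use case.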

Subclass

Examples of things that need to be subclasses are our ProxyTensors and FakeTensors.

With ProxyTensors, the basis of AOTAutograd, we want to know when the same tensor is passed as an argument more than once, so that we produce correct and succinct graphs. The graph should reuse the same Proxy object for that tensor, which means the Proxy has to travel with the tensor, so we need a subclass.

FakeTensors let us avoid the cost of actual computation while we’re tracing. To avoid running computations, we use the meta kernels, which means the tensor needs the Meta dispatch key without requiring that key on the user’s tensor. So we propagate the subclass through the computation to set dispatch keys and device types correctly throughout.

Note that both of these subclasses also use a mode to capture factory functions and wrap their outputs in the subclass. We believe this will typically be the intended behavior, so we generally recommend pairing a mode with your subclass, and we have some best practices for doing so.
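To make the subclass-plus-mode pattern concrete, here is a minimal sketch. TaggedTensor and TagFactoryMode are hypothetical and far simpler than ProxyTensor or FakeTensor. Each TaggedTensor wraps a plain tensor (elem) and carries a string tag; its __torch_dispatch__ unwraps the arguments, runs the op, and re-wraps the outputs with the joined tags of its inputs, so reusing a tensor shows up in the tag. TagFactoryMode wraps the outputs of calls with no tensor inputs (factory functions) so they enter the subclass too. The sketch relies on private APIs (torch.Tensor._make_wrapper_subclass, torch._C._disabled_torch_function_impl, torch.utils._pytree) that may change between releases.

```python
import torch
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils._python_dispatch import TorchDispatchMode


class TaggedTensor(torch.Tensor):
    """Hypothetical wrapper subclass that carries a string tag through every op."""

    # Skip the __torch_function__ layer so ops go straight to __torch_dispatch__.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem, tag):
        # The wrapper copies elem's metadata; the real data lives in r.elem.
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )
        r.elem = elem
        r.tag = tag
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        tags = []

        def unwrap(x):
            if isinstance(x, TaggedTensor):
                tags.append(x.tag)
                return x.elem
            return x

        # Run the op on the plain tensors, then re-wrap outputs with the inputs' tags.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        new_tag = "+".join(tags)
        return tree_map(
            lambda x: TaggedTensor(x, new_tag) if isinstance(x, torch.Tensor) else x, out
        )


class TagFactoryMode(TorchDispatchMode):
    """Hypothetical mode that pulls factory-function outputs into TaggedTensor."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)
        flat_inputs, _ = tree_flatten((args, kwargs))
        if any(isinstance(x, torch.Tensor) for x in flat_inputs):
            # Tensor inputs: the subclass (or a plain kernel) already handled this op.
            return out
        # No tensor inputs, so this was a factory call: tag its result.
        return tree_map(
            lambda x: TaggedTensor(x, "factory") if isinstance(x, torch.Tensor) else x, out
        )


a = TaggedTensor(torch.ones(2), "a")
c = a + a  # the reuse of a is visible in c.tag

with TagFactoryMode():
    x = torch.ones(2)  # factory call, wrapped by the mode
    y = x + x          # routed through TaggedTensor.__torch_dispatch__
```

The mode checks for tensor inputs after running the op: when there are any, the call has already gone through TaggedTensor.__torch_dispatch__ (or run on plain user tensors), so the mode passes the result through untouched.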

Extension points other than __torch_dispatch__?

Yes! All of this exists for __torch_function__ too. We haven’t seen as many use cases for it, but it works the same way as dispatch modes, just with TorchFunctionMode and a __torch_function__ implementation.
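A sketch of the __torch_function__ counterpart, assuming TorchFunctionMode from torch.overrides: the hypothetical FunctionRecorder records the Python-level torch functions that are called. Unlike the dispatch-level examples, func here is the public API object (such as torch.add), not an ATen overload, and factory functions like torch.ones are intercepted as well.

```python
import torch
from torch.overrides import TorchFunctionMode


class FunctionRecorder(TorchFunctionMode):
    """Hypothetical mode that records the names of the torch functions it sees."""

    def __init__(self):
        super().__init__()
        self.names = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # func is the Python-level function, e.g. torch.add or torch.ones
        self.names.append(getattr(func, "__name__", repr(func)))
        return func(*args, **kwargs)


recorder = FunctionRecorder()
with recorder:
    x = torch.ones(3)
    y = torch.add(x, x)

print(recorder.names)
```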

Inspired to try out modes?

🥳 If you run into any bugs 🐛, please file an issue on pytorch/pytorch!

Something missing? Feel free to file an issue for that too. This note focuses on the most basic usage of modes and leaves out many advanced features to avoid overwhelming readers. We’ve seen the benefits of basic modes and wanted to start there, but we’re excited to talk about the advanced usages!
