Higher Order Operators, 2023/10

What do autograd.Function, torch.utils.checkpoint, torch.cond, triton kernels, and torch.vmap all have in common? Answer: torch.compile support for the above APIs (and more!) goes through the Higher Order Operator (HOP) mechanism.

What is a HOP?

An operator in PyTorch (e.g. aten::sum, inspected in the snippet below) is a function with:

  • a schema that defines the acceptable input and output types
  • a set of optional implementations for backends and functionalities in PyTorch (e.g. CPU, CUDA, Autograd).
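For instance, you can inspect an operator’s schema from Python (a small illustration; _schema is an internal attribute, so the exact output may vary across releases):

import torch

# Inspect the schema of one overload of aten::sum.
op = torch.ops.aten.sum.default
print(op._schema)  # e.g. aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor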

A “higher-order” operator (HOP) is an operator that:

  • either accepts a Python function as input, returns a Python function as output, or both.
  • like all PyTorch operators, higher-order operators also have optional implementations for backends and functionalities. This lets us e.g. register an autograd formula for the higher-order operator or define how it behaves under ProxyTensor tracing.

Throughout the last year, we have created common infrastructure for (1) defining HOPs and (2) getting them to work with torch.compile. The team has leveraged HOPs to add torch.compile support for a number of existing APIs (autograd.Function, torch.utils.checkpoint), with more (Triton kernels, torch.vmap) on the way.

Background: torch.compile and operators

The torch.compile stack consists of three components: Dynamo, AOTDispatcher (sometimes known as AOTAutograd), and Inductor. At a high level, the responsibility of each component in torch.compile(func)(*args) is:

  • Dynamo is responsible for taking func, capturing it, and partitioning it into “traceable” subgraphs that work with the rest of the stack. Every time it discovers a subgraph, it passes it to AOTDispatcher.
  • AOTDispatcher takes the subgraph, functionalizes it, normalizes it into ATen IR, and also handles autograd if necessary. The normalized subgraph is then passed to Inductor.
  • Inductor takes the subgraph and generates fast device-specific kernels.

Built-in PyTorch operators pass through this stack. Take torch.add, for example: Dynamo will capture a call to torch.add in the graph, AOTDispatcher will normalize it into an aten::add call, and Inductor will generate efficient code for it.
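As a small illustration, you can watch this happen by compiling a trivial function with the relevant logs turned on (e.g. TORCH_LOGS="graph_code,aot_graphs,output_code"; the available log artifact names may differ slightly across releases):

import torch

@torch.compile
def add(x, y):
    return torch.add(x, y)

# With the logs above enabled, you should see the Dynamo-captured graph,
# the ATen graph produced by AOTDispatcher, and the Inductor-generated code.
add(torch.randn(4), torch.randn(4))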

What should torch.compile do when it sees an API that accepts a function? Let’s take torch.utils.checkpoint.checkpoint as an example. checkpoint(fun, *args, **kwargs) runs fun(*args, **kwargs), but does not save any intermediate Tensors for backwards. During the backwards pass, it recomputes the necessary intermediates for backward by re-running fun(*args, **kwargs).
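In eager mode, using it looks something like the following (a minimal example; use_reentrant=False selects the newer non-reentrant implementation):

import torch
from torch.utils.checkpoint import checkpoint

def fun(x):
    return x.sin().sin()

x = torch.randn(3, requires_grad=True)
# Intermediates of fun are not saved; they are recomputed during backward.
out = checkpoint(fun, x, use_reentrant=False)
out.sum().backward()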

When Dynamo sees torch.utils.checkpoint.checkpoint(fun, *args, **kwargs), it cannot directly record it into the graph to be passed to AOTDispatcher:

  • fun may not be “traceable”: that is, it may have side effects like print statements, or it may update global variables
  • fun may have free variables that can refer to objects not currently being tracked by torch.compile!

All HOPs share some common infrastructure to resolve these Dynamo-related issues and then have customizable handling to go through AOTDispatcher and Inductor.

Getting a HOP E2E through the torch.compile stack

To demonstrate how HOPs work, let’s discuss a simple HOP, torch.ops.higher_order.wrap (NB: this has nothing to do with torch.fx.wrap). The semantics of wrap(func, *args) are equivalent to the following Python function.

def wrap(func, *args):
    return func(*args)

What happens when we torch.compile a call to torch.ops.higher_order.wrap, like the following?

@torch.compile
def f(x, y):
    def body(x):
        return x + y
    result = torch.ops.higher_order.wrap(body, x)
    return result

Dynamo capture

Dynamo’s responsibility is to capture a graph that works with the rest of the PT2 stack. It is not safe for it to always insert a call to wrap(body, x) into the graph it generates, because the user-defined body may do things that are “not compilable”: the body may have side effects like print statements or it may update global variables, both of which cannot be compiled.

When Dynamo (the PT2 frontend) sees a call to wrap(func, *args), it:

  • attempts to trace func into a single subgraph that has no side-effects
  • lifts any free variables of func into being inputs of the traced subgraph
  • rewrites and records the call to wrap in the output graph

That is, Dynamo will trace the above code as the following (pseudocode):

def f_graph(self, x, y):
    result = torch.ops.higher_order.wrap(self.body_graph, x, y)
    return result

# NB: `body_graph`, unlike `body`, accepts y as an arg
def body_graph(self, x, y):
    add = x + y
    return add

If Dynamo is unable to trace the body into a single graph (because it contains graph breaks, or if there are non-local side effects in the body), then we call the body “unsafe” and Dynamo skips compilation and falls back to eager-mode PyTorch. All HOPs share the mechanism described above and therefore also these limitations.
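For example, a body like the following would be considered unsafe because of the print side effect (a sketch; per the above, torch.compile falls back to eager-mode PyTorch here):

import torch

@torch.compile
def f(x, y):
    def body(x):
        print("hi")  # side effect: Dynamo cannot safely trace this body
        return x + y
    return torch.ops.higher_order.wrap(body, x)

f(torch.randn(3), torch.randn(3))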

AOTDispatcher integration

A HOP can customize its behavior with AOTDispatcher by going through PyTorch Dispatcher extension points. Concretely, this is where a HOP can define its behavior under autograd and how it should appear in the graph that gets passed to Inductor.

We’ve specified torch.ops.higher_order.wrap(fun, *args)’s behavior under AOTDispatcher to just decompose to fun(*args), so AOTDispatcher produces the following graph for consumption by Inductor:

def f_graph(self, x, y):
    result = torch.ops.aten.add.Tensor(x, y)
    return result

Note that Inductor never ends up seeing the wrap (but instead sees what it decomposes into). This is the case for most existing HOPs: Inductor never sees them but instead sees what they decompose into.
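To give a flavor of what this looks like, here is a rough sketch of how a wrap-like HOP might be defined and given a “just decompose to the body” rule. This leans on internal APIs (HigherOrderOperator, py_impl, DispatchKey) whose exact spelling is an implementation detail and may differ across versions; the real wrap lives in torch/_higher_order_ops/wrap.py.

import torch
from torch._ops import HigherOrderOperator
from torch._C import DispatchKey

# NB: "my_wrap" is a hypothetical HOP for illustration only.
class MyWrap(HigherOrderOperator):
    def __init__(self):
        super().__init__("my_wrap")

my_wrap = MyWrap()

@my_wrap.py_impl(DispatchKey.CompositeExplicitAutograd)
def my_wrap_dense(func, *args):
    # Decompose: just run the body. Tracing then records whatever ops the
    # body runs (e.g. aten::add) instead of the HOP itself.
    return func(*args)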

HOPs Case Studies

In this section we highlight some work leveraging the HOP mechanism.

torch.utils.checkpoint.checkpoint

checkpoint(fun, *args, **kwargs) runs fun(*args, **kwargs), but does not save any intermediate Tensors for backwards. During the backwards pass, it recomputes the necessary intermediates for backward by re-running fun(*args, **kwargs).

When Dynamo sees checkpoint(fun, *args, **kwargs), it uses the HOP mechanism to trace fun into a subgraph. If this succeeds, we insert a call to a new torch.ops.higher_order.tag_activation_checkpoint(fun, *args, **kwargs) op into the Dynamo-produced graph. AOTDispatcher sees this operator and has special handling for it.

Going through a quick example, let’s say we have the following code:

@torch.compile
def f(x):
    out = torch.utils.checkpoint.checkpoint(g, x)
    return out

def g(x):
    w = x.sin()
    z = w.sin()
    return z

x = torch.randn([], requires_grad=True)
y = f(x)

Dynamo traces the code into the following:

def f_graph(self, x):
    g_graph = self.g_graph
    out = torch.ops.higher_order.tag_activation_checkpoint(g_graph, x)
    return out

def g_graph(self, x):
    w = x.sin()
    z = w.sin()
    return z

AOTDispatcher produces two graphs, a forward graph and a backward graph. It has special handling for tag_activation_checkpoint that results in recomputation of the nodes in the backwards pass:

def f_forward(self, x):
    w = x.sin()
    z = w.sin()
    return z

def f_backward(self, grad_z, x):
    w = x.sin()
    grad_w = grad_z * w.cos()
    grad_x = grad_w * x.cos()
    return grad_x

Finally, Inductor will optimize these two graphs passed to it.

Please see anijain2305’s PR over at https://github.com/pytorch/pytorch/pull/101028 for more details.

autograd.Function

Users use autograd.Function to specify a custom backward formula for a sequence of operations in PyTorch. An autograd.Function consists of two user-defined functions: a forward() and a backward().

When Dynamo sees an autograd.Function during training, it introspects both the forward() and the backward() for safety, and finally inserts a call to the autograd.Function’s apply method into the produced FX graph. AOTDispatcher then inlines the autograd.Function’s forward/backward directly into the forward and backward graphs it traces, which get passed to Inductor.
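For example, an autograd.Function like the toy MySin below is capturable this way:

import torch

class MySin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * torch.cos(x)

@torch.compile
def f(x):
    return MySin.apply(x)

x = torch.randn(3, requires_grad=True)
f(x).sum().backward()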

Please see voznesenskym’s PR over at https://github.com/pytorch/pytorch/pull/99483 for more details.

Control flow operators: torch.cond

Historically, torch.cond was our first HOP (it also came with the initial HOP mechanism!). It was designed and implemented by voznesenskym (https://github.com/pytorch/pytorch/pull/83154) and further improved this year by ydw4.

The motivation behind it is: Dynamo produces a graph break when it sees data-dependent control flow. Concretely, it is unable to trace the following into a full graph.

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    if x > 0:
        return x.sin()
    else:
        return x.cos()

x = torch.tensor(1.)
f(x)

This is because Dynamo traces a function by running through it with FakeTensors (tensors without storage), so it is unable to determine whether x (as a FakeTensor) is greater than 0. It graph breaks on the if-statement: with fullgraph=True, as above, this surfaces as an error; otherwise it falls back to eager-mode PyTorch for the data-dependent branch.

Our current recommendation for users who encounter this situation is to rewrite their function to use torch.cond. Semantically, torch.cond performs the following:

def cond(pred, true_fn, false_fn, args):
    if pred:
        return true_fn(*args)
    else:
        return false_fn(*args)
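The rewrite of the earlier example would look something like this (assuming torch.cond is exposed at the top level, as the name used in this note suggests; in some releases it lives under functorch.experimental.control_flow.cond):

import torch

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    # Both branches receive the same operands and must be free of side effects.
    return torch.cond(x > 0, lambda x: x.sin(), lambda x: x.cos(), (x,))

x = torch.tensor(1.)
f(x)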

When Dynamo sees a torch.cond, it will:

  • trace both the true_fn and false_fn and turn them into subgraphs
  • emit a call to torch.ops.higher_order.cond(pred, true_fn_graph, false_fn_graph, args) into the graph.

However, unlike the other HOPs above and more like built-in PyTorch operators, cond also has:

  • a functionalization rule to handle local mutations in true_fn and false_fn
  • a ProxyMode rule that specifies what happens when it gets traced with ProxyMode
  • a FakeTensorMode rule that specifies what happens when it gets passed a FakeTensor

Coming soon: Triton kernels, hardening, and more

Triton kernels

Triton is an easy way to author performant GPU kernels in Python. Users already use triton to author their custom kernels (e.g. flash attention) and we expect to see more in the future; torch.compile’s Inductor backend also leverages triton to generate efficient fused kernels.

We would like user-defined triton kernels to integrate with the torch.compile stack. That is, torch.compile should be able to optimize calls to user-defined triton kernels and improve performance via:

  • no graph breaks. Previously, torch.compile would graph break when it saw a triton kernel.
  • ahead-of-time compilation and autotuning of triton kernels
  • enabling optimizations like fusion between the user-defined triton kernel and triton code generated by Inductor

oulgen’s work on integrating triton kernels with torch.compile involves automatically wrapping the triton kernel into a HOP and then shepherding it through the torch.compile stack.

  • When Dynamo sees a call to a triton kernel, it replaces it with a call to a triton_wrapper_mutation(triton_kernel, grid, kwargs) HOP.
  • Triton kernels mutate their inputs. We added a functionalization rule so that AOTDispatcher converts a call to triton_wrapper_mutation to a call to a triton_wrapper_functional HOP
  • Inductor receives the triton_wrapper_functional HOP and figures out how to optimize and compile it in parallel with other Inductor-emitted kernels.

Follow along with the implementation over at https://github.com/pytorch/pytorch/pull/109623.
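From the user’s point of view, the goal is that code along these lines just works under torch.compile (a sketch with a toy add kernel; it needs a GPU and a PyTorch build where this integration has landed):

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

@torch.compile
def f(x, y):
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
f(x, y)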

Hardening

autograd.Function is the most popular API backed by the HOP mechanism. We’ve gotten numerous bug reports from internal users (and we expect to see many on the OSS side after the PyTorch 2.1 release) that, underneath torch.compile, not all autograd.Functions are capturable with the HOP mechanism; those that aren’t fall back to eager-mode PyTorch.

Many of these are self-imposed limitations in the HOP mechanism that should be liftable for all HOPs: e.g. we do not yet support HOPs that return more than one Tensor or accept inputs that are not a flat list of Tensors.

Backward hooks

Users can add backward hooks to Tensors that trigger when autograd computes gradients. These hooks can close over arbitrary variables and mutate arbitrary Python state that may not be available during the forward pass. In voznesenskym’s work on tracing backward hooks, we plan to wrap invocations of backward hooks in a special trace_wrapped HOP. Later on, when Dynamo sees a .backward() call, it will resolve the trace_wrapped HOPs by introspecting the hooks.
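For reference, this is the kind of user code in question (the hook closes over Python state that only gets touched at backward time):

import torch

grads = []

x = torch.randn(3, requires_grad=True)
y = x.sin()
# The hook closes over `grads` and fires when autograd computes y's gradient.
y.register_hook(lambda grad: grads.append(grad))
y.sum().backward()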

Learn more over at https://github.com/pytorch/pytorch/pull/109690

(stretch) Even more control flow ops

PyTorch has a torch.cond (if-statement) control flow operator. torch.scan (a restricted for-loop) is a common user request, and teams have asked for a torch.while_loop.
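For reference, here is roughly what a scan computes, written as plain Python (this follows the jax.lax.scan convention; the eventual PyTorch API may differ):

def scan(fn, init, xs):
    # The carry is threaded through the loop; one output is stacked per element of xs.
    carry = init
    ys = []
    for x in xs:
        carry, y = fn(carry, x)
        ys.append(y)
    return carry, ys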

At the end of the day, our control flow operators execute in Python: torch.cond ultimately gets executed as a Python if-statement after it makes it through the torch.compile stack. Users have asked for control flow operators that are lowered to device-side code (e.g. CUDA kernels); a device-side scan could allow researchers to experiment with recurrent architectures without being bitten by the Python overhead that traditionally dominates in such models.

(stretch) functorch transforms

functorch (now torch.func) offers function transforms like vmap and grad that users can compose to easily produce advanced autodiff quantities such as jacobians and hessians. These function transform APIs accept a function as an input and return a function but do not yet work with torch.compile, making the HOP mechanism a good fit for them. kshitij12345 worked on a prototype of torch.vmap and torch.func.grad support for torch.compile that we will continue pushing forward.
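As a reminder, here is what these transforms look like in eager mode today (e.g. per-sample gradients via vmap-of-grad); the prototype aims to let torch.compile capture calls like these:

import torch
from torch.func import grad, vmap

def f(x):
    return torch.sin(x).sum()

x = torch.randn(3)
g = grad(f)(x)                                  # gradient of f at a single input
per_sample = vmap(grad(f))(torch.randn(5, 3))   # per-sample gradients over a batch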

Acknowledgements

Thanks to the following people for developing and improving the core HOP mechanism:

Thanks to the following additional folks for feedback on an earlier version of this note: gchanan, supriyar, oulgen
