Custom Ops Under torch.compile: autograd.Function vs torch.library.custom_op

When integrating custom kernels (CUDA, Triton, etc.) into a PyTorch training loop,
you need to tell PyTorch how to run the kernel and how to differentiate it.
Two APIs exist for this, and they interact with torch.compile very differently.

For the full reference, see the Custom Ops landing page and the Custom Ops manual.

API 1: torch.autograd.Function

Note: This is the most widely used API, but it is not the recommended
one if you need torch.compile integration. Most users should migrate to
torch.library.custom_op described below.

The classic approach – define forward and backward in a single class:

class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        out, lse, S = my_kernel_fwd(q, k, v)
        ctx.save_for_backward(q, k, v, lse, S)
        return out, lse, S  # intermediates must be returned -- see below

    @staticmethod
    def backward(ctx, grad_output, grad_lse, grad_S):
        q, k, v, lse, S = ctx.saved_tensors
        return my_kernel_bwd(q, k, v, lse, S, grad_output)

def my_op(q, k, v):
    out, _lse, _S = MyOp.apply(q, k, v)  # drop intermediates at the call site
    return out

What happens under torch.compile?

When torch.compile encounters MyOp.apply(...), Dynamo traces inside the
forward and backward methods. Whether this succeeds depends entirely on what
those methods contain:

  • Pure PyTorch ops: Dynamo traces through them, everything compiles. No graph
    break.
  • Numpy calls: Dynamo converts them to PyTorch equivalents and compiles through.
  • Custom C++/CUDA kernel calls (via pybind11, torch.ops, cpp_extension, etc.):
    Dynamo graph-breaks because it can’t symbolically trace through the kernel –
    specifically, it can’t determine the output shapes and dtypes without a
    FakeTensor/Meta implementation.

Most real custom ops fall into the third category. When a graph break happens,
the compiled graph gets split into pieces around the custom op call, and the op
itself runs in eager mode. You get multiple compiled “frames” instead of one
unified graph, losing optimization opportunities.

API 2: torch.library.custom_op

The compile-friendly approach – register the op in PyTorch’s dispatcher:

import torch
from torch import Tensor

@torch.library.custom_op("mylib::my_op", mutates_args=())
def _my_op(q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    out, lse, S = my_kernel_fwd(q, k, v)
    return out, lse, S

@_my_op.register_fake
def _(q, k, v):
    # No real computation -- just describe the output shapes and dtypes.
    return (
        torch.empty_like(q),
        torch.empty(q.shape[0], q.shape[1], q.shape[2], device=q.device),
        torch.empty_like(q),
    )

def setup_context(ctx, inputs, output):
    q, k, v = inputs
    out, lse, S = output
    ctx.save_for_backward(q, k, v, lse, S)

def backward(ctx, grad_out, grad_lse, grad_S):
    q, k, v, lse, S = ctx.saved_tensors
    return my_kernel_bwd(q, k, v, lse, S, grad_out)

_my_op.register_autograd(backward, setup_context=setup_context)

def my_op(q, k, v):
    out, _lse, _S = _my_op(q, k, v)  # drop intermediates at the call site
    return out

Dynamo treats this as an opaque node in the graph (like torch.mm). It never
tries to trace inside. The register_fake implementation provides the shape/dtype
metadata needed for symbolic tracing. No graph break, guaranteed.

Answers to common questions

Q1: Does autograd.Function fall back to eager under torch.compile?

It depends on what’s inside forward(). Dynamo traces into the method body. If
the kernel call is opaque (no FakeTensor/Meta implementation), Dynamo graph-breaks
at that point. The surrounding code compiles, but the custom op itself runs in
eager.

You can verify this by counting compiled frames (see runnable example below).

Q2: Why register autograd separately?

Because torch.compile needs different information at different stages:

Stage                            What the compiler needs     What provides it
Tracing (compile-time)           Output shapes and dtypes    register_fake
Execution (forward pass)         The actual kernel           custom_op function body
Differentiation (backward pass)  Gradient formula            register_autograd

autograd.Function bundles all three into one class. When Dynamo traces into
forward() and hits an opaque kernel, it can’t satisfy the tracing stage –
it doesn’t know what shapes come out. The custom_op API forces you to provide
this information explicitly via register_fake, which is what makes it
compile-friendly.

Q3: The forward doesn’t have ctx – how do I save intermediates?

Return them as additional outputs from the custom op. The setup_context
callback receives the complete forward output, so you can unpack and save
whatever you need:

import torch
from torch import Tensor

@torch.library.custom_op("mylib::my_op", mutates_args=())
def my_op(x: Tensor) -> tuple[Tensor, Tensor]:
    result = torch.sin(x)
    intermediate = torch.cos(x)   # needed for backward, expensive to recompute
    return result, intermediate

def setup_context(ctx, inputs, output):
    result, intermediate = output  # <-- full forward output is available
    ctx.save_for_backward(intermediate)

setup_context runs inline during the forward pass, right after the op executes.
It is not rerunning forward – it just receives the output that was already
computed. No redundant work.

Runnable example

The example below demonstrates everything with simulated “custom kernels”
(registered C++ ops without Meta implementations, which is the situation custom
kernel authors typically face).

import torch
from torch import Tensor

# --- Simulate custom C++ kernels (registered ops without Meta impl) ---

_lib = torch.library.Library("demo", "DEF")
_lib.define("kernel_fwd(Tensor x) -> (Tensor, Tensor)")
_lib.impl("kernel_fwd", lambda x: (torch.sin(x), torch.cos(x)), "CPU")

_lib.define("kernel_bwd(Tensor cos_x, Tensor grad) -> Tensor")
_lib.impl("kernel_bwd", lambda cos_x, grad: grad * cos_x, "CPU")


# --- API 1: autograd.Function ---

class SinOpV1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out, cos_x = torch.ops.demo.kernel_fwd(x)
        ctx.save_for_backward(cos_x)
        return out

    @staticmethod
    def backward(ctx, grad):
        cos_x, = ctx.saved_tensors
        return torch.ops.demo.kernel_bwd(cos_x, grad)

def sin_v1(x):
    return SinOpV1.apply(x)


# --- API 2: custom_op ---

@torch.library.custom_op("demo::sin_op_v2", mutates_args=())
def sin_v2(x: Tensor) -> tuple[Tensor, Tensor]:
    return torch.ops.demo.kernel_fwd(x)

@sin_v2.register_fake
def _(x):
    return torch.empty_like(x), torch.empty_like(x)

def _setup_ctx(ctx, inputs, output):
    _out, cos_x = output
    ctx.save_for_backward(cos_x)

def _backward(ctx, grad_out, _grad_cos):
    cos_x, = ctx.saved_tensors
    return torch.ops.demo.kernel_bwd(cos_x, grad_out)

sin_v2.register_autograd(_backward, setup_context=_setup_ctx)


# --- Compare under torch.compile ---

frame_count = 0
def counting_backend(gm, example_inputs):
    global frame_count
    frame_count += 1
    return gm

x = torch.randn(8, requires_grad=True)

# API 1
frame_count = 0
f1 = torch.compile(lambda x: sin_v1(x).sum(), backend=counting_backend)
loss1 = f1(x)
loss1.backward()
grad1 = x.grad.clone()
print(f"autograd.Function: {frame_count} frame(s) {'(graph broke!)' if frame_count > 1 else ''}")

# API 2
x.grad = None
frame_count = 0
f2 = torch.compile(lambda x: sin_v2(x)[0].sum(), backend=counting_backend)
loss2 = f2(x)
loss2.backward()
grad2 = x.grad.clone()
print(f"custom_op:          {frame_count} frame(s)")

# Verify correctness
print(f"\nGradients match: {torch.allclose(grad1, grad2)}")
print(f"Reference (cos):    {torch.cos(x[:4]).tolist()}")
print(f"API 1 grad:         {grad1[:4].tolist()}")
print(f"API 2 grad:         {grad2[:4].tolist()}")

Expected output:

autograd.Function: 2 frame(s) (graph broke!)
custom_op:          1 frame(s)

Gradients match: True
Reference (cos):    [...]
API 1 grad:         [...]        # same values
API 2 grad:         [...]        # same values

Both produce correct gradients, but custom_op compiles as a single graph while
autograd.Function graph-breaks around the opaque kernel call.

Summary

                         autograd.Function                    torch.library.custom_op
Eager mode               works                                works
torch.compile            graph-breaks on opaque kernels       no graph break
Forward intermediates    ctx.save_for_backward() in forward   return as extra outputs, save in setup_context
Autograd                 in the class                         register_autograd
Shape info for compiler  not provided                         register_fake

Recommendation: Use torch.library.custom_op for custom kernels that need to
work with torch.compile. The API is more explicit – you tell the compiler exactly
what shapes your op produces (register_fake), how to differentiate it
(register_autograd), and what it computes (the function body). This separation is
what makes graph-break-free compilation possible. Beyond compile, going through the
dispatcher also gives you multiple backend registration (CPU, CUDA, XPU, etc. from
a single op definition), faster inference_mode dispatch, and correct interaction
with other PyTorch subsystems like vmap.

Why you should never save non-input/non-output tensors

A common mistake with autograd.Function is saving intermediate tensors (ones that
are neither inputs nor outputs) via ctx.save_for_backward:

# BAD: lse and S are intermediates, not inputs or outputs
class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        out, lse, S = my_kernel_fwd(q, k, v)
        ctx.save_for_backward(q, k, v, lse, S)  # lse, S are neither input nor output!
        return out

This breaks several things:

  1. Double backward (higher-order gradients) breaks. If you need to
    differentiate through the backward pass (e.g., for second-order optimization or
    torch.autograd.grad with create_graph=True), the saved intermediates must
    carry proper autograd metadata. Only tensors that are inputs or outputs of the
    autograd.Function have this metadata. Intermediates stashed in ctx that are
    neither input nor output lack the autograd graph linkage needed to propagate
    gradients through the backward pass itself.

  2. torch.compile tracing fails. When Dynamo traces inside the
    autograd.Function, it builds a graph of the forward. Intermediates stashed in
    ctx but not returned become invisible to the graph – they are side effects.
    The compiler cannot reason about them, leading to incorrect behavior or errors.

  3. Incorrect memory accounting. The autograd engine tracks the lifecycle of saved
    tensors based on input/output relationships. Tensors that are neither input nor
    output escape this tracking, potentially keeping large buffers alive longer than
    expected.

The fix: Return intermediates as additional outputs and drop them at the call
site:

# GOOD: lse and S are outputs, dropped by the wrapper
class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        out, lse, S = my_kernel_fwd(q, k, v)
        ctx.save_for_backward(q, k, v, lse, S)  # all are inputs or outputs now
        return out, lse, S

    @staticmethod
    def backward(ctx, grad_output, grad_lse, grad_S):
        q, k, v, lse, S = ctx.saved_tensors
        return my_kernel_bwd(q, k, v, lse, S, grad_output)

def my_op(q, k, v):
    out, _lse, _S = MyOp.apply(q, k, v)
    return out

This pattern makes the intermediates visible to the autograd engine, activation
checkpointing, and torch.compile. The custom_op API naturally encourages this
pattern since setup_context already receives the full output.
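To make the double-backward point concrete, here is a small check using pure-PyTorch stand-ins for the kernels. With only an input (x) saved and the extra output's gradient accounted for in backward, create_graph=True produces the correct second derivative:

```python
import torch

class SinGood(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)            # x is an input: safe to save
        return torch.sin(x), torch.cos(x)   # intermediate returned as output

    @staticmethod
    def backward(ctx, grad_out, grad_cos):
        (x,) = ctx.saved_tensors
        # Account for both outputs so double backward is correct:
        # d(sin)/dx = cos(x), d(cos)/dx = -sin(x).
        return grad_out * torch.cos(x) - grad_cos * torch.sin(x)

x = torch.randn(4, requires_grad=True)
y, _ = SinGood.apply(x)
(g,) = torch.autograd.grad(y.sum(), x, create_graph=True)
(gg,) = torch.autograd.grad(g.sum(), x)
print(torch.allclose(g, torch.cos(x)))    # True: first derivative
print(torch.allclose(gg, -torch.sin(x)))  # True: second derivative
```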
