When integrating custom kernels (CUDA, Triton, etc.) into a PyTorch training loop,
you need to tell PyTorch how to run the kernel and how to differentiate it.
Two APIs exist for this, and they interact with torch.compile very differently.
For the full reference, see the Custom Ops landing page and the Custom Ops manual.
API 1: torch.autograd.Function
Note: This is the most widely used API, but it is not the recommended
one if you need torch.compile integration. Most users should migrate to
torch.library.custom_op, described below.
The classic approach – define forward and backward in a single class:
```python
class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        out, lse, S = my_kernel_fwd(q, k, v)
        ctx.save_for_backward(q, k, v, lse, S)
        return out, lse, S  # intermediates must be returned -- see below

    @staticmethod
    def backward(ctx, grad_output, grad_lse, grad_S):
        q, k, v, lse, S = ctx.saved_tensors
        return my_kernel_bwd(q, k, v, lse, S, grad_output)

def my_op(q, k, v):
    out, _lse, _S = MyOp.apply(q, k, v)  # drop intermediates at the call site
    return out
```
What happens under torch.compile?
When torch.compile encounters MyOp.apply(...), Dynamo traces inside the
forward and backward methods. Whether this succeeds depends entirely on what
those methods contain:
- Pure PyTorch ops: Dynamo traces through them, everything compiles. No graph
  break.
- NumPy calls: Dynamo converts them to PyTorch equivalents and compiles through.
- Custom C++/CUDA kernel calls (via pybind11,
torch.ops, cpp_extension, etc.):
Dynamo graph-breaks because it can’t symbolically trace through the kernel –
specifically, it can’t determine the output shapes and dtypes without a
FakeTensor/Meta implementation.
Most real custom ops fall into the third category. When a graph break happens,
the compiled graph gets split into pieces around the custom op call, and the op
itself runs in eager mode. You get multiple compiled “frames” instead of one
unified graph, losing optimization opportunities.
API 2: torch.library.custom_op
The compile-friendly approach – register the op in PyTorch’s dispatcher:
```python
@torch.library.custom_op("mylib::my_op", mutates_args=())
def _my_op(q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    out, lse, S = my_kernel_fwd(q, k, v)
    return out, lse, S

@_my_op.register_fake
def _(q, k, v):
    # No real computation -- just describe the output shapes and dtypes.
    return (
        torch.empty_like(q),
        torch.empty(q.shape[0], q.shape[1], q.shape[2], device=q.device),
        torch.empty_like(q),
    )

def setup_context(ctx, inputs, output):
    q, k, v = inputs
    out, lse, S = output
    ctx.save_for_backward(q, k, v, lse, S)

def backward(ctx, grad_out, grad_lse, grad_S):
    q, k, v, lse, S = ctx.saved_tensors
    return my_kernel_bwd(q, k, v, lse, S, grad_out)

_my_op.register_autograd(backward, setup_context=setup_context)

def my_op(q, k, v):
    out, _lse, _S = _my_op(q, k, v)  # drop intermediates at the call site
    return out
```
Dynamo treats this as an opaque node in the graph (like torch.mm). It never
tries to trace inside. The register_fake implementation provides the shape/dtype
metadata needed for symbolic tracing. No graph break, guaranteed.
Answers to common questions
Q1: Does autograd.Function fall back to eager under torch.compile?
It depends on what’s inside forward(). Dynamo traces into the method body. If
the kernel call is opaque (no FakeTensor/Meta implementation), Dynamo graph-breaks
at that point. The surrounding code compiles, but the custom op itself runs in
eager.
You can verify this by counting compiled frames (see runnable example below).
Q2: Why register autograd separately?
Because torch.compile needs different information at different stages:
| Stage | What the compiler needs | What provides it |
|---|---|---|
| Tracing (compile-time) | Output shapes and dtypes | register_fake |
| Execution (forward pass) | The actual kernel | custom_op function body |
| Differentiation (backward pass) | Gradient formula | register_autograd |
autograd.Function bundles all three into one class. When Dynamo traces into
forward() and hits an opaque kernel, it can’t satisfy the tracing stage –
it doesn’t know what shapes come out. The custom_op API forces you to provide
this information explicitly via register_fake, which is what makes it
compile-friendly.
Q3: The forward doesn’t have ctx – how do I save intermediates?
Return them as additional outputs from the custom op. The setup_context
callback receives the complete forward output, so you can unpack and save
whatever you need:
```python
@torch.library.custom_op("mylib::my_op", mutates_args=())
def my_op(x: Tensor) -> tuple[Tensor, Tensor]:
    result = torch.sin(x)
    intermediate = torch.cos(x)  # needed for backward, expensive to recompute
    return result, intermediate

def setup_context(ctx, inputs, output):
    result, intermediate = output  # <-- full forward output is available
    ctx.save_for_backward(intermediate)

def backward(ctx, grad_result, grad_intermediate):
    intermediate, = ctx.saved_tensors
    return grad_result * intermediate  # d/dx sin(x) = cos(x)

my_op.register_autograd(backward, setup_context=setup_context)
```
setup_context runs inline during the forward pass, right after the op executes.
It is not rerunning forward – it just receives the output that was already
computed. No redundant work.
Runnable example
The example below demonstrates everything with simulated “custom kernels”
(registered C++ ops without Meta implementations, which is the situation custom
kernel authors typically face).
```python
import torch
from torch import Tensor

# --- Simulate custom C++ kernels (registered ops without Meta impl) ---
_lib = torch.library.Library("demo", "DEF")
_lib.define("kernel_fwd(Tensor x) -> (Tensor, Tensor)")
_lib.impl("kernel_fwd", lambda x: (torch.sin(x), torch.cos(x)), "CPU")
_lib.define("kernel_bwd(Tensor cos_x, Tensor grad) -> Tensor")
_lib.impl("kernel_bwd", lambda cos_x, grad: grad * cos_x, "CPU")

# --- API 1: autograd.Function ---
class SinOpV1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out, cos_x = torch.ops.demo.kernel_fwd(x)
        ctx.save_for_backward(cos_x)
        return out

    @staticmethod
    def backward(ctx, grad):
        cos_x, = ctx.saved_tensors
        return torch.ops.demo.kernel_bwd(cos_x, grad)

def sin_v1(x):
    return SinOpV1.apply(x)

# --- API 2: custom_op ---
@torch.library.custom_op("demo::sin_op_v2", mutates_args=())
def sin_v2(x: Tensor) -> tuple[Tensor, Tensor]:
    return torch.ops.demo.kernel_fwd(x)

@sin_v2.register_fake
def _(x):
    return torch.empty_like(x), torch.empty_like(x)

def _setup_ctx(ctx, inputs, output):
    _out, cos_x = output
    ctx.save_for_backward(cos_x)

def _backward(ctx, grad_out, _grad_cos):
    cos_x, = ctx.saved_tensors
    return torch.ops.demo.kernel_bwd(cos_x, grad_out)

sin_v2.register_autograd(_backward, setup_context=_setup_ctx)

# --- Compare under torch.compile ---
frame_count = 0

def counting_backend(gm, example_inputs):
    global frame_count
    frame_count += 1
    return gm

x = torch.randn(8, requires_grad=True)

# API 1
frame_count = 0
f1 = torch.compile(lambda x: sin_v1(x).sum(), backend=counting_backend)
loss1 = f1(x)
loss1.backward()
grad1 = x.grad.clone()
print(f"autograd.Function: {frame_count} frame(s) {'(graph broke!)' if frame_count > 1 else ''}")

# API 2
x.grad = None
frame_count = 0
f2 = torch.compile(lambda x: sin_v2(x)[0].sum(), backend=counting_backend)
loss2 = f2(x)
loss2.backward()
grad2 = x.grad.clone()
print(f"custom_op: {frame_count} frame(s)")

# Verify correctness
print(f"\nGradients match: {torch.allclose(grad1, grad2)}")
print(f"Reference (cos): {torch.cos(x[:4]).tolist()}")
print(f"API 1 grad: {grad1[:4].tolist()}")
print(f"API 2 grad: {grad2[:4].tolist()}")
```
Expected output:
```
autograd.Function: 2 frame(s) (graph broke!)
custom_op: 1 frame(s)

Gradients match: True
Reference (cos): [...]
API 1 grad: [...]  # same values
API 2 grad: [...]  # same values
```
Both produce correct gradients, but custom_op compiles as a single graph while
autograd.Function graph-breaks around the opaque kernel call.
Summary
| | autograd.Function | torch.library.custom_op |
|---|---|---|
| Eager mode | works | works |
| torch.compile | graph-breaks on opaque kernels | no graph break |
| Forward intermediates | ctx.save_for_backward() in forward | return as extra outputs, save in setup_context |
| Autograd | in the class | register_autograd |
| Shape info for compiler | not provided | register_fake |
Recommendation: Use torch.library.custom_op for custom kernels that need to
work with torch.compile. The API is more explicit – you tell the compiler exactly
what shapes your op produces (register_fake), how to differentiate it
(register_autograd), and what it computes (the function body). This separation is
what makes graph-break-free compilation possible. Beyond compile, going through the
dispatcher also gives you multiple backend registration (CPU, CUDA, XPU, etc. from
a single op definition), faster inference_mode dispatch, and correct interaction
with other PyTorch subsystems like vmap.
Why you should never save non-input/non-output tensors
A common mistake with autograd.Function is saving intermediate tensors (ones that
are neither inputs nor outputs) via ctx.save_for_backward:
```python
# BAD: lse and S are intermediates, not inputs or outputs
class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        out, lse, S = my_kernel_fwd(q, k, v)
        ctx.save_for_backward(q, k, v, lse, S)  # lse, S are neither input nor output!
        return out
```
This breaks several things:
- Double backward (higher-order gradients) breaks. If you need to
  differentiate through the backward pass (e.g., for second-order optimization
  or torch.autograd.grad with create_graph=True), the saved intermediates must
  carry proper autograd metadata. Only tensors that are inputs or outputs of
  the autograd.Function have this metadata. Intermediates stashed in ctx that
  are neither input nor output lack the autograd graph linkage needed to
  propagate gradients through the backward pass itself.
- torch.compile tracing fails. When Dynamo traces inside the
  autograd.Function, it builds a graph of the forward. Intermediates stashed
  in ctx but not returned become invisible to the graph – they are side
  effects. The compiler cannot reason about them, leading to incorrect
  behavior or errors.
- Incorrect memory accounting. The autograd engine tracks the lifecycle of
  saved tensors based on input/output relationships. Tensors that are neither
  input nor output escape this tracking, potentially keeping large buffers
  alive longer than expected.
The fix: Return intermediates as additional outputs and drop them at the call
site:
```python
# GOOD: lse and S are outputs, dropped by the wrapper
class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        out, lse, S = my_kernel_fwd(q, k, v)
        ctx.save_for_backward(q, k, v, lse, S)  # all are inputs or outputs now
        return out, lse, S

    @staticmethod
    def backward(ctx, grad_output, grad_lse, grad_S):
        q, k, v, lse, S = ctx.saved_tensors
        return my_kernel_bwd(q, k, v, lse, S, grad_output)

def my_op(q, k, v):
    out, _lse, _S = MyOp.apply(q, k, v)
    return out
```
This pattern makes the intermediates visible to the autograd engine, activation
checkpointing, and torch.compile. The custom_op API naturally encourages this
pattern since setup_context already receives the full output.