`torch.compile` / AOTAutograd: getting the backward graph's `_inductor` function

How can I get a runnable copy of the backwards graph when running torch.compile?

For example:

```python
import torch

@torch.compile
def f(x):
    out = x.sin() + x.cos()
    return out

x = torch.ones(2, requires_grad=True).cuda()
out = f(x)
```
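(One thing I'm unsure about: the snippet above only runs the forward pass. If the backward kernel is only compiled lazily when backward actually executes, I assume a continuation along these lines would be needed before any backward code could land in the cache:)

```python
# continuation of the snippet above (assumption: the backward kernel is
# compiled/executed lazily, so backward has to actually run at least once)
loss = out.sum()
loss.backward()
```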

If I run this with `TORCH_LOGS=all`, I can see that the backward graph is printed:

```
===== Backward graph 0 =====
 <eval_with_key>.36 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[2]", tangents_1: "f32[2]"):
        # File: /notebooks/clean-repos/triton-autodiff/test_compile.py:32 in f, code: tmp1 = x.sin() + x.cos()
        sin: "f32[2]" = torch.ops.aten.sin.default(primals_1)
        cos: "f32[2]" = torch.ops.aten.cos.default(primals_1);  primals_1 = None
        neg: "f32[2]" = torch.ops.aten.neg.default(sin);  sin = None
        mul: "f32[2]" = torch.ops.aten.mul.Tensor(tangents_1, neg);  neg = None
        mul_1: "f32[2]" = torch.ops.aten.mul.Tensor(tangents_1, cos);  tangents_1 = cos = None
        
        # File: /notebooks/clean-repos/triton-autodiff/test_compile.py:32 in f, code: tmp1 = x.sin() + x.cos()
        add_1: "f32[2]" = torch.ops.aten.add.Tensor(mul, mul_1);  mul = mul_1 = None
        return [add_1]
```
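(As an aside on how I got those logs: instead of `TORCH_LOGS=all`, my understanding is that the AOT graphs and the generated code can be requested more selectively, either through the environment variable or programmatically. I believe `aot_graphs` and `output_code` are the relevant artifact names for `torch._logging.set_logs`, but treat that as an assumption about my PyTorch version:)

```python
import torch

# narrower alternative to TORCH_LOGS=all; assumes these artifact names
# (aot_graphs, output_code) are accepted by this PyTorch build
torch._logging.set_logs(aot_graphs=True, output_code=True)
```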

However, if I set `TORCH_COMPILE_DEBUG=1` and look in the Inductor cache, I only see the forward kernel (`output.py`):


```python
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: torch_compile/cache/rv/crviokfbw7lr4fi4yf2kqhtyjostdohc32ulg7vlfbyprvjinxaa.py
# Source Nodes: [cos, sin, tmp1], Original ATen: [aten.add, aten.cos, aten.sin]
# cos => cos
# sin => sin
# tmp1 => add
triton_poi_fused_add_cos_sin_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers

@pointwise(
    size_hints=[2], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cos_sin_0', 'mutated_arg_names': [], 'no_x_dim': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.sin(tmp0)
    tmp2 = tl.cos(tmp0)
    tmp3 = tmp1 + tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)
''')
...
```

How can I get the equivalent output for the backward graph, and, beyond that, the `torch.autograd.Function` that is created during compilation?
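To make concrete the kind of handle I'm after, here is the closest sketch I have so far: wrapping custom forward/backward compilers with the `aot_autograd` backend helper seems to expose the backward FX graph as a runnable `torch.fx.GraphModule` (names like `my_fw_compiler`, `my_bw_compiler`, and `captured` are mine, and I'm not certain this is the intended API for this):

```python
import torch
from functorch.compile import make_boxed_func
from torch._dynamo.backends.common import aot_autograd

captured = {}  # hypothetical holder for the graphs handed to the compilers

def my_fw_compiler(gm: torch.fx.GraphModule, example_inputs):
    captured["forward"] = gm
    return make_boxed_func(gm.forward)  # run the FX graph as-is, no Inductor lowering

def my_bw_compiler(gm: torch.fx.GraphModule, example_inputs):
    captured["backward"] = gm  # same ATen-level backward graph as in the logs above
    gm.print_readable()
    return make_boxed_func(gm.forward)

my_backend = aot_autograd(fw_compiler=my_fw_compiler, bw_compiler=my_bw_compiler)

@torch.compile(backend=my_backend)
def f(x):
    return x.sin() + x.cos()

x = torch.ones(2, requires_grad=True)
f(x).sum().backward()            # triggers my_bw_compiler
bw_graph = captured["backward"]  # runnable torch.fx.GraphModule for the backward pass
```

This gives me the ATen-level backward graph, but it bypasses Inductor entirely, so I still don't see the Triton kernel that would be generated for the backward, nor the `torch.autograd.Function` that the default `inductor` backend builds around the two compiled halves.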