torch.export with AOT-Inductor for CuTe DSL kernels

Hi,

I was wondering if there is a mechanism to export a model with torch.export
that contains a torch.library.custom_op that calls a CUDA kernel.

Ideally, I would like to use a Python DSL (like Triton or the CuTe DSL) if possible, and I could rewrite the kernel in CUDA C++ if there is no other option.
So far, I can export the model, and I imagine the only requirement is to re-register my custom_op before running it in another Python instance. However, my goal is to run inference in C++.
AOT-Inductor is the recommended way to export a model for later use from C++, right?

My questions are the following:

  1. I believe such DSL kernels can be captured with a CUDA graph. Is it possible to hint torch.export or custom_op to record a CUDA graph of my DSL kernel? As far as I could see, libtorch or e.g. TensorRT can run CUDA graphs (I sketch below, right after these questions, what I have in mind).
  2. If 1 is not possible, would this work straight away if the custom_op is written in CUDA C++ as described in Custom C++ and CUDA Operators — PyTorch Tutorials 2.9.0+cu128 documentation? That tutorial seems mainly aimed at torch.compile.
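To make question 1 concrete, here is roughly what I have in mind: capturing the launch of the already-compiled CuTe kernel manually with torch.cuda.CUDAGraph. It reuses kernel_compiled and from_dlpack from the toy script further down, and I have not verified that the CuTe runtime is capture-safe, so treat it as a sketch of the intent rather than working code.

# Sketch only: manual CUDA graph capture of the compiled CuTe kernel launch.
# kernel_compiled and from_dlpack come from the toy script below.
static_in = torch.rand(12, 5, 6, 10, dtype=torch.float32, device="cuda")
static_out = torch.zeros(1, dtype=torch.float32, device="cuda")
in_cu = from_dlpack(static_in, assumed_align=16)
out_cu = from_dlpack(static_out, assumed_align=16)

# Warm-up launch on a side stream before capture, as the CUDA graphs docs recommend
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    kernel_compiled(in_cu, out_cu)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    kernel_compiled(in_cu, out_cu)

# Replaying re-runs the captured launch on the same static buffers
g.replay()
print(static_out.item())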

I wrote a toy script.

"""
File: playground/cute_cuda_torch_export.py
Author: Juan Montesinos

Description:
    Example of an attempt to compile a simple CuTe DSL kernel that adds all elements
    of a tensor and exports the result back to PyTorch.
"""

from math import prod
from typing import Callable

import torch
from cutlass import cute
from cutlass.cute.runtime import from_dlpack


# toy kernel to compile
@cute.kernel
def add_all_kernel(input: cute.Tensor, output: cute.Tensor):
    tx = cute.arch.thread_idx()[0]
    bx = cute.arch.block_idx()[0]
    bdim_x = cute.arch.block_dim()[0]

    tx_g = bdim_x * bx + tx
    acc = cute.Float64(0.0)
    if tx_g == 0:
        for i in range(input.shape[0]):
            acc += input[i]
        output[0] = acc.to(cute.Float32)


@cute.jit
def add_all(input: cute.Tensor, output: cute.Tensor):
    N = prod(input.shape)
    flat_layout = cute.make_layout((N,), stride=(1,))
    new_input = cute.make_tensor(input.iterator, flat_layout)
    kernel = add_all_kernel(new_input, output)
    kernel.launch(grid=(1, 1, 1), block=(1, 1, 1))


def compile_add_all(input_t: torch.Tensor) -> Callable:
    sum_t = torch.zeros(1, dtype=torch.float32).to(input_t.device)
    sum_cu = from_dlpack(sum_t, assumed_align=16)
    input_cu = from_dlpack(input_t, assumed_align=16)
    compiled_kernel = cute.compile(add_all, input_cu, sum_cu)
    return compiled_kernel


def add_all_torch(input_t: torch.Tensor, kernel: Callable) -> torch.Tensor:
    sum_t = torch.zeros(1, dtype=input_t.dtype).to(input_t.device)
    sum_cu = from_dlpack(sum_t, assumed_align=16)
    input_cu = from_dlpack(input_t, assumed_align=16)
    kernel(input_cu, sum_cu)
    return sum_t


tensor = torch.rand(12, 5, 6, 10, dtype=torch.float32).cuda().contiguous()
result_gt = tensor.sum()


# Warm-up to compile the kernel
kernel_compiled = compile_add_all(tensor)

result = add_all_torch(tensor, kernel_compiled)


# Register the kernel as a custom op
@torch.library.custom_op("cute::add_all", mutates_args=(), device_types="cuda")
def add_all_op(input_t: torch.Tensor) -> torch.Tensor:
    return add_all_torch(input_t, kernel_compiled)


@add_all_op.register_fake
def _(input_t):
    return torch.empty(1, dtype=input_t.dtype, device=input_t.device)


# Check correctness
result = add_all_op(tensor)
diff = result_gt - result
print(f"The difference is {diff} as")

# Run torch.export
class AddAllMod(torch.nn.Module):
    def forward(self, input_t: torch.Tensor) -> torch.Tensor:
        return add_all_op(input_t)

model = AddAllMod()
print("Exporting the custom op...")
exported_mod = torch.export.export(model, (tensor,))
result_exported = exported_mod.module()(tensor)
diff_exported = result_gt - result_exported
print(f"The difference from exported function is {diff_exported.item()}")
torch.export.save(exported_mod, "exported_program.pt2")
print('Exported successfully!')

# Try aot compile
so_path = torch._inductor.aot_compile(exported_mod.module(), (tensor,))
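If that call goes through, my understanding is that it returns the path of a compiled shared library that I could sanity-check from Python before moving to the C++ runner. A minimal sketch, assuming torch._export.aot_load is available in this PyTorch version:

# Load the AOTI-compiled artifact back in Python as a quick sanity check
loaded = torch._export.aot_load(so_path, device="cuda")
result_aot = loaded(tensor)
print(f"The difference from the AOTI-compiled model is {(result_gt - result_aot).item()}")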

Hey!

@desertfire would have the best answer for this one.

As of today, if you use a Triton custom op, it will be handled the same way as other Triton code generated by Inductor, so it will be pre-compiled and bundled as a pure binary. You can use that without any issue, irrespective of the runtime.
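For concreteness, a minimal sketch of that path, using the triton_op/wrap_triton decorators from the torch.library docs for user-defined Triton kernels (illustrative only; exact APIs may differ across versions):

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def add_one_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + 1.0, mask=mask)


# Because the op is declared with triton_op and launched through wrap_triton,
# Inductor can see the Triton source and pre-compile/bundle it under AOTI.
@triton_op("mylib::add_one", mutates_args={})
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)
    wrap_triton(add_one_kernel)[grid](x, out, n, BLOCK_SIZE=1024)
    return out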

AOT-Inductor is the recommended way to export a model for later use from C++, right?

Not sure what you mean by that? AOTI is for when you want to get access to the Inductor-generated code in a pre-compiled fashion instead of JIT.

For the other cases, it will unfortunately depend on the C++ runtime you use. The only one that is really suggested in core is ExecuTorch, which can delegate to AOTI if you want fused kernels. This is relatively new, though.

Hi @albanD
thanks for your answer.

Not sure what you mean by that? AOTI is for when you want to get access to the Inductor-generated code in a pre-compiled fashion instead of JIT.

As far as I could see, in order to run in Python-less environments, exporting with AOTI is a must, right?
For example, as per the (PyTorch) TensorRT docs:

Similarly, the ExecuTorch docs show per-backend compilation already happening in Python.

As of today, if you use a Triton custom op, it will be handled the same way as other Triton code generated by Inductor, so it will be pre-compiled and bundled as a pure binary.

Is that also the case for C++ CUDA kernels?

Thanks for your support!

As @albanD pointed out, user-defined Triton kernels work seamlessly with the PT2 stack, including AOTI, but support for the CuTe DSL hasn't reached the same level yet. Registering a custom CUDA kernel is the right way to go; make sure you follow Custom C++ and CUDA Operators — PyTorch Tutorials 2.9.0+cu128 documentation and register a FakeTensor kernel.
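For example, assuming the CUDA C++ side registers cute::add_all via TORCH_LIBRARY and is built into a shared library (the library name below is hypothetical), the Python side only needs to load it and attach a FakeTensor kernel:

import torch

# Hypothetical shared library that registers cute::add_all through TORCH_LIBRARY
torch.ops.load_library("libcute_add_all.so")


@torch.library.register_fake("cute::add_all")
def _(input_t):
    # FakeTensor kernel: only shape/dtype/device propagation, no real launch
    return input_t.new_empty(1)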