Hi,
I was wondering whether there is a mechanism to export, with torch.export, a model that contains a torch.library.custom_op which calls a CUDA kernel.
Ideally I would like to write the kernel in a Python DSL (such as Triton or the CuTe DSL), and I could rewrite it in CUDA C++ if there is no other option.
So far I can export the model, and I imagine the only requirement is to re-register my custom op before running it in another Python instance (see the reload check near the end of the script). However, my goal is to run inference from C++.
AOTInductor is the recommended way to export a model for later use from C++, right?
My questions are the following:
- I believe such DSL kernels can be captured in a CUDA graph. Is it possible to hint torch.export or the custom_op to record a CUDA graph of my DSL kernel? As far as I can tell, libtorch and e.g. TensorRT can replay CUDA graphs. (A rough Python-side capture sketch follows this list.)
- If 1 is not possible, would this work straight away if the custom_op were implemented in CUDA C++, as described in Custom C++ and CUDA Operators — PyTorch Tutorials 2.9.0+cu128 documentation? That tutorial seems to be aimed mainly at torch.compile, so I am not sure it covers torch.export / AOTInductor.
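To make question 1 concrete, here is a minimal sketch of what I would like to happen under the hood: a manual capture of the compiled kernel with torch.cuda.CUDAGraph on the Python side. It assumes the DSL launch is stream-capturable and reuses the kernel_compiled / from_dlpack / tensor names from the toy script below.

# Hedged sketch: manual CUDA graph capture of the compiled DSL kernel.
# Assumes the launch is stream-capturable; names come from the script below.
static_in = tensor.clone()
static_out = torch.zeros(1, dtype=torch.float32, device="cuda")

# Warm-up on a side stream before capture
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    kernel_compiled(from_dlpack(static_in, assumed_align=16),
                    from_dlpack(static_out, assumed_align=16))
torch.cuda.current_stream().wait_stream(side)

# Capture, then replay with new data copied into the static input buffer
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    kernel_compiled(from_dlpack(static_in, assumed_align=16),
                    from_dlpack(static_out, assumed_align=16))
static_in.copy_(tensor)
graph.replay()
print(static_out)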
I wrote a toy script.
"""
File: playground/cute_cuda_torch_export.py
Author: Juan Montesinos
Description:
Example of an attemp to compile a simple CUTE DSL kernel that adds all elements
of a tensor and exports the result back to PyTorch.
"""
from math import prod
from typing import Callable
import torch
from cutlass import cute
from cutlass.cute.runtime import from_dlpack
# toy kernel to compile
@cute.kernel
def add_all_kernel(input: cute.Tensor, output: cute.Tensor):
    tx = cute.arch.thread_idx()[0]
    bx = cute.arch.block_idx()[0]
    bdim_x = cute.arch.block_dim()[0]
    tx_g = bdim_x * bx + tx + 0
    acc = cute.Float64(0.0)
    if tx_g == 0:
        for i in range(input.shape[0]):
            acc += input[i]
        output[0] = acc.to(cute.Float32)
@cute.jit
def add_all(input: cute.Tensor, output: cute.Tensor):
    N = prod(input.shape)
    flat_layout = cute.make_layout((N,), stride=(1,))
    new_input = cute.make_tensor(input.iterator, flat_layout)
    kernel = add_all_kernel(new_input, output)
    kernel.launch(grid=(1, 1, 1), block=(1, 1, 1))
def compile_add_all(input_t: torch.Tensor) -> Callable:
    sum_t = torch.zeros(1, dtype=torch.float32).to(input_t.device)
    sum_cu = from_dlpack(sum_t, assumed_align=16)
    input_cu = from_dlpack(input_t, assumed_align=16)
    compiled_kernel = cute.compile(add_all, input_cu, sum_cu)
    return compiled_kernel
def add_all_torch(input_t: torch.Tensor, kernel: Callable) -> torch.Tensor:
    sum_t = torch.zeros(1, dtype=input_t.dtype).to(input_t.device)
    sum_cu = from_dlpack(sum_t, assumed_align=16)
    input_cu = from_dlpack(input_t, assumed_align=16)
    kernel(input_cu, sum_cu)
    return sum_t
tensor = torch.rand(12, 5, 6, 10, dtype=torch.float32).cuda().contiguous()
result_gt = tensor.sum()
# Warm-up to compile the kernel
kernel_compiled = compile_add_all(tensor)
result = add_all_torch(tensor, kernel_compiled)
# Register the kernel as a custom op
@torch.library.custom_op("cute::add_all", mutates_args=(), device_types="cuda")
def add_all_op(input_t: torch.Tensor) -> torch.Tensor:
    return add_all_torch(input_t, kernel_compiled)

@add_all_op.register_fake
def _(input_t):
    # The fake tensor must match the real output's dtype and device
    return torch.empty(1, dtype=input_t.dtype, device=input_t.device)
# Check correctness
result = add_all_op(tensor)
diff = result_gt - result
print(f"The difference is {diff} as")
# Run torch.export
class AddAllMod(torch.nn.Module):
    def forward(self, input_t: torch.Tensor) -> torch.Tensor:
        return add_all_op(input_t)
model = AddAllMod()
print("Exporting the custom op...")
exported_mod = torch.export.export(model, (tensor,))
result_exported = exported_mod.module()(tensor)
diff_exported = result_gt - result_exported
print(f"The difference from exported function is {diff_exported.item()}")
torch.export.save(exported_mod, "exported_program.pt2")
print('Exported successfully!')
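# Sanity check that the saved program loads back. In a fresh Python process the
# custom op would have to be registered again (the same torch.library.custom_op
# + register_fake code above) before executing the loaded program; here it is
# still registered, so this only checks the save/load round trip in-process.
reloaded = torch.export.load("exported_program.pt2")
result_reloaded = reloaded.module()(tensor)
print(f"The difference after reloading is {(result_gt - result_reloaded).item()}")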
# Try AOT compilation
exported_aot = torch._inductor.aot_compile(exported_mod.module(), (tensor,))
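For completeness, the path I would try next for C++ deployment is the newer AOTInductor packaging API, applied directly to the exported program. This is only a sketch of my plan; I have not verified that it handles a Python-defined custom op such as cute::add_all.

# Hedged sketch: AOTInductor packaging of the exported program (newer API).
# Unverified with the Python-side cute::add_all custom op.
pt2_path = torch._inductor.aoti_compile_and_package(
    exported_mod,
    package_path="add_all_aoti.pt2",
)
# Load the package back on the Python side and run it; on the C++ side the
# package would be loaded with libtorch's AOTIModelPackageLoader instead.
compiled = torch._inductor.aoti_load_package(pt2_path)
print(compiled(tensor))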
