[RFC] Adding Triton Backend for Aten operators

Abstract

This RFC discusses

  1. the benefits and challenges of developing dispatch functions for ATen operators in Triton;

  2. our practice of adding Triton backend functions for ATen operators.

Motivation

PyTorch now has 2600+ entries in native_functions.yaml (as of this RFC), roughly 600 more than what was reported in this post. The large number of operators poses a challenge for GPU vendors or other hardware manufacturers who want to add a new backend for PyTorch: an op-by-op implementation is needed for comprehensive operator support.

There are some efforts to decompose PyTorch operators into a smaller set of operators, or to lower the PyTorch ATen IR into other IRs with a smaller operator set, such as FuncTorch, PrimTorch and torch inductor. However, these methods either work in a trace-and-compile manner or are too slow to run in eager mode. They sit above the PyTorch dispatcher and are very different from the eager execution mode.

We propose an alternative solution to the problem: writing operators in a language with multi-backend support. In short, it is a write-once, compile-anywhere method. Kernels are written once and compiled into different binaries with different compiler backends. This solution is orthogonal to the solutions listed above. Instead of reducing the number of operators by decomposing them into a small set, it offloads multi-backend support to the compiler. Since it integrates into PyTorch at the ATen dispatcher, it works seamlessly with eager execution.

With a higher abstraction level (tile-oriented, CTA-level programming) and decent performance, Triton is getting more attention from deep learning developers. It is used in torch inductor as the code generation language, and also in many LLM training/inference libraries (lightllm, vllm, unsloth) and kernel libraries (Liger Kernels).

In addition, Triton is open-source and supports multiple backends. Thus, it is getting more support from accelerator manufacturers, since supporting Triton makes their devices more attractive platforms to developers and the industry as well.

Though Triton has been widely used to develop custom operators, there have not been many attempts to write Triton kernels for standard operators like pointwise operations, reductions, normalizations and GEMMs. These standard kernels have been studied extensively and may already be implemented in dedicated libraries written in platform-specific programming languages. But developing these standard operators in Triton saves GPU manufacturers much effort in supporting a wide variety of ATen operators, especially for those without a comprehensive and highly optimized software stack.

Ezyang made a proof-of-concept showing that Triton can be used to develop kernels for ATen operators.

We created FlagGems, a high-performance general operator library implemented in Triton.

Background

Handwritten kernels vs compiler generated kernels

Torch inductor lowers torch IR into an internal IR with a very small set of operators and generates Triton code for a wide variety of operators that are composed of pointwise, reduction and scan operations. We chose to create a library of handwritten kernels for the following reasons.

  1. Coverage: There are some operators that are not easily generated by torch-inductor, for example, sort, argsort, unique and topk. Handwritten kernels have broader coverage than generated ones.

  2. Performance: Handwritten kernels may perform better than compiler-generated kernels. Since inductor performs analysis and fusion at the IRNode and SchedulerNode levels, where higher-level semantics are lost, some optimizations that rely on non-obvious mathematical transformations cannot be applied, for example the online softmax normalizer (sketched after this list) and flash attention. We can apply these optimizations to handwritten kernels.

  3. Flexibility: It is also more flexible to optimize handwritten kernels than to optimize a full torch.compile stack. torch.compile has several components, such as operator decomposition, fusion, code generation and config selection, some of which are shared in the compilation of many operators; modifying those components would affect multiple operators. While handwritten kernels may also share some components, it is relatively easy to restrict changes to only some operators or patterns.
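
For illustration, here is a minimal Python sketch (our own, not FlagGems code) of the online softmax normalizer recurrence: the running maximum m and the running normalizer d are updated block by block in a single pass, instead of one full pass for the maximum and another for the sum.

import math

def online_softmax_normalizer(blocks):
    """Compute the softmax max and normalizer over a stream of blocks in one pass."""
    m = float("-inf")   # running maximum
    d = 0.0             # running sum of exp(x - m)
    for block in blocks:
        new_m = max(m, max(block))
        # rescale the old sum to the new maximum, then add this block's contribution
        d = d * math.exp(m - new_m) + sum(math.exp(x - new_m) for x in block)
        m = new_m
    return m, d

With m and d available, softmax(x_i) = exp(x_i - m) / d, so the data only needs to be read one more time. Expressing this kind of fusion after the operator has been decomposed into separate max, sub, exp and sum nodes is much harder.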

Design

Overview

The overview of FlagGems is shown above. FlagGems consists of kernels (jit functions, actually) written in Triton and several components that allow them to work with DL frameworks and device APIs.

  1. Pytorch Integration describes how FlagGems integrates into pytorch.

  2. Wrapper describes the structure of the wrapper code around the launch of jit functions, and the challenges to the efficiency and portability of wrapper code.

  3. Jit function discusses how to write efficient jit functions and how to efficiently find reasonable kernel configs. This part also shows how we deal with Triton lang's limitations when writing jit functions for arbitrarily ranked tensors.

  4. Triton runtime discusses how to reduce runtime overhead of Triton.

  5. Triton compilers describes how FlagGems works with multiple backends.

Pytorch Integration

FlagGems integrates into PyTorch by registering wrapper functions into the dispatcher. Since the project was started before torch.library.custom_op came into being, the lower-level API torch.Library.impl is used.

  1. We create a wrapper function with the same signature as the corresponding ATen native function.

  2. The wrapper function is registered into PyTorch's dispatcher with the torch.Library APIs, overriding the existing implementations for the CUDA backend.

For standard operators, we register our implementations following some rules to work with autograd:

  1. For a native function with backend-specific implementations, we override the ‘CUDA’ backend; the backward implementation defined in ATen remains in effect.

  2. For a native function without backend-specific implementations and with only a CompositeExplicitAutograd implementation, we register a function for the ‘CUDA’ backend, which has higher priority than CompositeExplicitAutograd. The backward implementation defined in ATen also remains in effect.

  3. For a native function without backend-specific implementations and with only a CompositeImplicitAutograd implementation, which has no explicitly defined backward pass (its backward pass is composed of the backward passes of the operators used in its implementation), we have to provide both the forward and the backward pass, wrap them into a torch.autograd.Function, and register it with the AutogradCUDA key.

The rule of thumb is: only register with the CUDA or AutogradCUDA key, and only use the latter when we want to, or have to, make autograd use our backward pass.
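
A minimal sketch of these two registration patterns is given below. The wrapper bodies, the Triton-backed helpers (triton_add, cosine_similarity_fwd/bwd) and the choice of example operators are illustrative placeholders, not FlagGems' actual code.

import torch

lib = torch.library.Library("aten", "IMPL")  # attach implementations to existing aten ops

def add_tensor(x, y, *, alpha=1):
    # wrapper that launches a Triton jit function (stubbed here)
    return triton_add(x, y, alpha)  # hypothetical Triton-backed implementation

# Rules 1 and 2: register a forward-only override under the CUDA key;
# the backward formula already defined in ATen keeps working.
lib.impl("add.Tensor", add_tensor, "CUDA")

# Rule 3: a CompositeImplicitAutograd op (cosine_similarity is used as an example here)
# needs an explicit forward/backward pair wrapped in a torch.autograd.Function
# and registered under the AutogradCUDA key.
class CosineSimilarity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x1, x2, dim, eps):
        out, saved = cosine_similarity_fwd(x1, x2, dim, eps)  # hypothetical Triton kernel
        ctx.save_for_backward(*saved)
        ctx.dim, ctx.eps = dim, eps
        return out

    @staticmethod
    def backward(ctx, grad_out):
        grad_x1, grad_x2 = cosine_similarity_bwd(  # hypothetical Triton kernel
            grad_out, *ctx.saved_tensors, ctx.dim, ctx.eps
        )
        return grad_x1, grad_x2, None, None

def cosine_similarity(x1, x2, dim=1, eps=1e-8):
    return CosineSimilarity.apply(x1, x2, dim, eps)

lib.impl("cosine_similarity", cosine_similarity, "AutogradCUDA")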

To work with torch.compile, we need to make sure that overriding the dispatch does not break the dispatch function for FakeTensor. We currently find that ATen operators whose AutogradCUDA implementation is added or overridden fail to work correctly with graph tracing in dynamo.

For custom ops, we follow the Python custom op guide in PyTorch.

Multi-backend Support

FlagGems depends on Triton, but the Triton package differs across device types. For example, when working on AMD GPUs, FlagGems needs to work with the Triton and torch builds for ROCm devices.

To support multiple backends without forking the project for each of them, there are some requirements.

  1. The dispatch key to register may differ across backends. A method for device and backend detection is required to obtain the desired dispatch key to register (see the sketch below).

  2. The APIs used in wrapper functions should be device-agnostic, for example APIs for creating new empty tensors on a specified device or for switching the device context. We may use the device-agnostic Python runtime APIs in the future, as the RFC proposed, or wrap our own.

An alternative method is to introduce a second-layer dispatcher and leave dispatching to different devices to it.
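
A minimal sketch of requirement 1, assuming that ROCm builds of PyTorch reuse the CUDA dispatch keys and that other accelerators are exposed as a PrivateUse1 backend; the function name and mapping are our own, not FlagGems' actual code.

import torch

def pick_dispatch_keys():
    """Return the (backend key, autograd key) pair to use for registration on this build."""
    if torch.version.hip is not None:
        return "CUDA", "AutogradCUDA"  # ROCm builds reuse the CUDA dispatch keys
    if torch.cuda.is_available():
        return "CUDA", "AutogradCUDA"
    # out-of-tree accelerators are typically registered as a PrivateUse1 backend
    return "PrivateUse1", "AutogradPrivateUse1"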

Wrapper

In many cases, wrappers are simple. A conventional wrapper includes the following tasks (a minimal sketch follows the list):

  1. input argument checking and error handling;

  2. meta data inference and allocation of outputs;

  3. device context switch and call to the jit functions;

  4. return of the results.
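
Here is a minimal sketch of such a wrapper for an elementwise add. The kernel and wrapper are our own illustration (contiguous inputs are assumed by copying), not FlagGems' actual implementation.

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # 1. input argument checking (plus a copy to contiguous, see below)
    assert x.shape == y.shape and x.device == y.device and x.dtype == y.dtype
    x, y = x.contiguous(), y.contiguous()
    # 2. metadata inference and allocation of the output
    out = torch.empty_like(x)
    # 3. device context switch and launch of the jit function
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch.cuda.device(x.device):
        add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    # 4. return the result
    return out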

But these tasks (except the call to the jit function itself) may also incur significant overhead, especially for operations on small tensors, mainly for the following reasons.

  1. When the kernel does not support arbitrary strided layouts, we need to copy the input tensors to contiguous ones, which introduces extra data copying;

  2. Metadata inference for outputs may be expensive. Though the device, dtype and shape of the outputs are well defined for most operators, the stride order of the outputs is flexible. Selecting the best stride order for the outputs can be expensive, especially when implemented in Python.

  3. Selecting the best algorithm may be expensive. Some wrappers may call different jit functions according to input sizes, data layout, contiguity, etc.

Although there are workarounds that bypass this overhead, for example CUDA Graphs, we are working on reducing the per-invocation CPU overhead, since performance in eager execution matters.

In addition, we may add another layer of indirection by providing a series of wrappers that do not necessarily follow the signatures of ATen functions, serving as the interface of a framework-neutral math library.

Code Reuse in Jit functions

Developing Triton jit functions is basically kernel programming. Below we discuss code reuse when developing jit functions in FlagGems.

Pointwise operations, which seem to be the simplest operations to write in Triton and always serve as the first example of Triton programming, are nevertheless hard to make flexible with respect to the input tensor's rank and layout, mainly due to the limitations of Triton lang's features.

  1. It is not easy to write a single Triton jit function that supports arbitrarily-ranked torch tensors with strided layouts. CUDA kernels for some ATen operators support arbitrarily-ranked tensors either by making the inputs contiguous beforehand or by capturing tensor strides for up to 25 dimensions via CUDA lambda captures. Enforcing contiguous inputs or copying them to contiguous can also be done in Python, but a Triton jit function has no similar mechanism to capture tensors' strides.

  2. Triton does not support user-defined structs or static arrays as parameters to jit functions, so we added a runtime code generation mechanism that generates Triton code for pointwise operators according to the rank of the input tensors. Many other operators, like pad, gather, scatter and kron, have similar issues.

  3. Triton jit functions have no support for higher-order functions. While all pointwise operations share similar logic, there is no way to pass the computation on scalars as a lambda into a Triton jit function. That is another reason why we use code generation for pointwise operators.

Here's an example of how to use the resulting pointwise_dynamic decorator.

@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    return x / y

The function decorated with pointwise_dynamic supports arbitrary ranks, arbitrary strides and broadcasting.
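
A brief usage sketch (the tensors and shapes are our own example): the decorated function is called like a normal Python function from an operator's wrapper, and non-contiguous and broadcast inputs are expected to be handled automatically.

import torch

x = torch.randn(4, 8, device="cuda")[:, ::2]  # non-contiguous view of shape (4, 4)
y = torch.randn(4, 1, device="cuda")          # broadcast along the last dimension
out = true_div_func(x, y)                     # expected to match torch.true_divide(x, y)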

There is a similar/related solution to this problem in xformers. It works by extending the syntax and parser of the Triton language to support varargs: parameters or variables annotated with VAR_ARGS_ARRAY are unrolled by the extension to instantiate a concrete jit function. This solution requires the user to write templates and instantiate them explicitly. However, it does not support varargs of constexpr, and it does not address the problem of supporting higher-order functions.

Here’s an example of how to use it.

@triton.jit
def weighted_sumN(
    output_ptr,
    a_ptr: "VAR_ARGS_ARRAY",  # type: ignore # noqa: F821
    b: "VAR_ARGS_ARRAY",  # type: ignore # noqa: F821
    BLOCK_SIZE: tl.constexpr,
):
    # Weighted sum, where the weights are on CPU
    offset = tl.arange(0, BLOCK_SIZE)
    output = tl.zeros([BLOCK_SIZE], tl.float32)
    for i in range(len(a_ptr)):
        output = output + tl.load(a_ptr[i] + offset) * b[i]
    tl.store(output_ptr + offset, output)
...
kernel = unroll_varargs(weighted_sumN, N=NUM_INPUTS)
kernel[(1,)](output, *a, *b_list, BLOCK_SIZE=32)

There are also other kinds of code reuse in developing a library of operators, for example common patterns for finding the best configs, but we do not cover them here.

Triton Compile

Triton can be used in an aot-compile or jit-compile manner, though it is mainly used in the jit-compile manner. The pros and cons of the aot and jit ways of using Triton in operator libraries are:

aot

Pros:

  1. faster cold launch, since there is no autotuning at runtime.

  2. no dependency on the Python runtime, which is better for deployment.

Cons:

  1. large compiled artifact and longer build time;

  2. extra development of a build system, including enumerating all possible kernels of each Triton jit function, tuning the jit functions, packaging all the generated kernels, and a dispatcher to select kernels at runtime.

jit

Pros:

  1. simpler workflow for development and packaging;

Cons:

  1. slow warmup when autotuning is involved.

  2. dependency on python runtime.

We currently choose jit-compile for simpler development and packaging. When we revisit this topic from the deployment perspective, we may try some aot-based methods.

Related projects:

  1. aottriton: a project by ROCm that employs aot compilation. It re-implements an almost full-fledged build system and a runtime so that Triton can be used in an aot manner to build an operator library.

  2. AOTInductor: a project in PyTorch to compile exported models ahead of time into a shared library that can be run in a non-Python environment.

Triton Runtime

There have been several issues and improvement proposals for Triton's runtime for jit functions.

  1. [FRONTEND][RFC] Low latency kernel launching

  2. Faster jit

  3. Even faster jit

JIT Runtime

The main reason for the slow jit runtime is that Triton launches all jit functions in a boxed way. Instead of simply calling an ordinary function, calling a jit function involves the following tasks at each invocation, which are slow.

  1. Parameter binding and routing. The signature of the jit function is analyzed beforehand, but the parameters received at runtime must be bound to the parameters defined in the signature, then classified and routed to different sets: some are constexpr parameters or compilation parameters (e.g. num_warps, num_stages, num_ctas), while others are parameters passed to the compiled kernel.

  2. Cache key computation. The Triton runtime extracts features from the input parameters to specialize kernels, by checking the parameters' dtypes, values, or the divisibility of integers and pointers by some predefined values, for example 16.
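
An illustrative sketch (not Triton's actual code) of the kind of per-argument features that end up in such a specialization key:

import torch

def specialization_key(args):
    """Illustrative only: collect the argument features that select a compiled kernel."""
    key = []
    for a in args:
        if torch.is_tensor(a):
            # dtype and pointer alignment (divisibility by 16) affect code generation
            key.append((str(a.dtype), a.data_ptr() % 16 == 0))
        elif isinstance(a, int):
            # the value 1 and divisibility by 16 allow extra specialization
            key.append(("int", a == 1, a % 16 == 0))
        else:
            key.append((type(a).__name__, a))
    return tuple(key)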

Other features like heuristics, the autotuner and hooks make the jit runtime more complicated. Recent refactoring of the jit runtime has reduced the overhead, but there is still room to optimize runtime performance.

LibEntry

We also have a faster runtime, LibEntry, in FlagGems. We observe that going through the autotuning and heuristics logic is compulsory for a jit function even if the kernel is already compiled. To reduce this extra overhead, we install a fast-track lookup function, called LibEntry, at the entrance of each Triton jit function. LibEntry caches a mapping from input parameters to compiled kernels for each jit function. If the LibEntry cache is hit, the jit function skips ahead and runs the saved kernel, bypassing the intermediate Autotuner and Heuristics. Otherwise, control falls back to the original Triton runtime.
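
A simplified sketch of the caching idea (names, structure and the launch convention are our own, not FlagGems' actual implementation); compile_via_triton stands for the full Heuristics/Autotuner/compile pipeline and is assumed to return a launchable kernel.

class LibEntryLike:
    """Memoize the kernel chosen by the full Triton runtime, keyed by cheap argument features."""

    def __init__(self, compile_via_triton, make_key):
        self.compile_via_triton = compile_via_triton  # heuristics + autotune + compile -> kernel
        self.make_key = make_key                      # e.g. the specialization_key sketch above
        self.cache = {}

    def __call__(self, grid, *args, **kwargs):
        key = self.make_key(args)
        kernel = self.cache.get(key)
        if kernel is None:
            # cache miss: pay the full Triton runtime cost once and remember the result
            kernel = self.compile_via_triton(*args, **kwargs)
            self.cache[key] = kernel
        # launch the cached kernel directly, bypassing Autotuner and Heuristics
        kernel[grid](*args, **kwargs)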

There are also other projects working on persisting the results of the autotuner (an in-memory cache) to disk, so that they can be restored and reused in later runs.

We are working on combining these methods to reduce runtime overhead and reuse tuning results.

Related projects:

  1. triton-dejavu: a project by IBM that reduces the autotune overhead of Triton jit functions by storing and restoring autotuner results in JSON files, which avoids auto-tuning on every run.

Summary

In this RFC we discuss the benefits and challenges of developing ATen operators in Triton, and our practice in FlagGems.

Benefits:

  1. Triton has a higher level of abstraction and is easy to learn;

  2. Triton compiler delivers decent performance;

  3. Triton natively supports multiple backends and is gaining support from many GPU manufacturers.

Challenges:

  1. Triton is mainly used in a just-in-time fashion, and its compiler and runtime depend on Python, so it is not straightforward to create a library of Triton jit functions without a dependency on Python (especially in deployment).

  2. Triton lang has limited features, making it harder to reuse code in kernel development;

  3. Triton's runtime has high overhead;

  4. Implementing wrapper functions in Python has higher runtime overhead than in C++;

  5. It takes some effort to override ATen operators' dispatch functions in Python and make them work with other PyTorch subsystems, for example torch.autograd and torch.compile.

Our goal is to develop a high-performance operator library in Triton with multi-backend support in a single-source manner, and to integrate it into PyTorch by using it to implement ATen operators.