TorchInductor: a PyTorch-native Compiler with Define-by-Run IR and Symbolic Shapes

The PyTorch team has been building TorchDynamo, which solves the graph capture problem of PyTorch through dynamic Python bytecode transformation. To actually make PyTorch faster, TorchDynamo must be paired with a compiler backend that converts the captured graphs into fast machine code. We have already integrated numerous backends, and built a lightweight autotuner to select the best backend for each subgraph.

Unfortunately, a lot is lost in translation when exporting to backends that differ greatly from PyTorch. Many involve multiple conversion steps, have a fundamentally different execution model than PyTorch, and have limited support for many PyTorch operators and features. Another key challenge is that, with a few exceptions, most backends are inference-only and have made design decisions that make training support intractable. This export-based execution model will be important for a lot of specific applications, but PyTorch needs a native compiler with abstractions that closely mirror those of PyTorch.

In prior updates, you have heard about our PyTorch-native compilers nvFuser and NNC.

In this post, we will introduce TorchInductor. TorchInductor is a new compiler for PyTorch, which is able to represent all of PyTorch and is built in a general way so that it will be able to support training and multiple backend targets. TorchInductor is able to represent aliasing and mutation through the concepts of TensorBox and StorageBox, which map one-to-one to torch.Tensor and torch.Storage. It handles views with a symbolically strided tensor representation that maps directly onto the native torch.Tensor stride representation. Other parts of PyTorch are handled similarly, by mirroring the data model of PyTorch in the backend. The design philosophy is a thin, easily hackable way of symbolically mapping PyTorch to lower-level backends, enabling rapid experimentation, autotuning between different backends, and higher-level optimizations such as memory planning.

TorchInductor Design

TorchInductor is implemented in Python. There are pros and cons to this choice, but we have found it greatly increases velocity and developer productivity. We have also observed that the PyTorch community is much more likely to contribute to parts of PyTorch written in Python, which makes the system more approachable and hackable for our users.

To force the design of TorchInductor to be general, we are starting off with two lower-level execution targets that represent different points in the design space:

  • Triton is a new programming language that provides much higher productivity than CUDA, but with the ability to beat the performance of highly optimized libraries like cuDNN with clean and simple code. It is developed by Philippe Tillet at OpenAI, and is seeing enormous adoption and traction across the industry. Triton supports NVIDIA GPUs, and is quickly growing in popularity as a replacement for hand written CUDA kernels.
  • C++/OpenMP is a widely adopted specification for writing parallel kernels. OpenMP provides a work sharing parallel execution model, and enables support for CPUs. C++ is also an interesting target in that it is a highly portable language and could enable export to more exotic edge devices and hardware architectures.

The approach to building TorchInductor is a breadth-first one. We have spent most of our time making sure the core infrastructure is able to support the vast majority of PyTorch, including: aliasing/mutation/views, scatter (indirect writes), gather (indirect reads), pooling/windows/reductions, masked/conditional execution (padding, etc.), template epilogue fusions, tiling, and horizontal/vertical fusions. So far we have not spent much time optimizing any one pattern, instead focusing on general optimizations with widespread benefits.

TorchInductor supports dynamic shapes and strides using the SymPy symbolic math library. It specializes on sizes of zero and one, but assigns every other unique tensor size to a sympy.Symbol and flows it through the entire program. Memory loads and stores are represented directly as SymPy indexing formulas based on the iteration variables and symbolic tensor sizes. When needed, TorchInductor introduces guards that lift any assumptions/requirements to the top of the subgraph and trigger a recompile if those guards fail.
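
For intuition, here is a hand-written sketch (not TorchInductor's actual code) of what those symbolic sizes, indexing formulas, and guards look like as SymPy objects:

import sympy

# Symbolic sizes for a 2-D tensor of shape (s0, s1); sizes of 0 and 1 are
# specialized rather than turned into symbols.
s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)

# Iteration variables for a pointwise loop over the output.
i0, i1 = sympy.symbols("i0 i1", integer=True, nonnegative=True)

# Loads/stores are plain SymPy indexing formulas in those symbols.
contiguous_index = i0 * s1 + i1   # x[i0, i1] for strides (s1, 1)
transposed_index = i1 * s1 + i0   # x.permute(1, 0)[i0, i1] over the same storage

# Guards lift assumptions to the top of the subgraph; if one fails at
# runtime (e.g. a size later comes in as 1), the subgraph is recompiled.
guard = sympy.Gt(s1, 1)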

Another unique aspect of TorchInductor is it uses a define-by-run loop-level intermediate representation (IR). Many parts of the IR are Python callables that take SymPy expressions as inputs. We analyze this IR and do codegen by changing the implementation of ops.* and running the IR. As an example, the IR for x.permute(1, 0) + x[2, :] might be something like:

def inner_fn(index: List[sympy.Expr]):
    i1, i0 = index
    # read of x.permute(1, 0) at this output position (a strided/transposed load)
    tmp0 = ops.load("x", i1 + i0*size1)
    # read of x[2, :] (row 2 of x, broadcast over the other output dimension)
    tmp1 = ops.load("x", 2*size1 + i0)
    return ops.add(tmp0, tmp1)

torchinductor.ir.Pointwise(
    device=torch.device("cuda"),
    dtype=torch.float32,
    inner_fn=inner_fn,
    ranges=[size0, size1],
)

Where inner_fn defines how to compute a single element of the output buffer.
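
Because the IR is define-by-run, an analysis is just another implementation of ops.*. As a rough sketch (hypothetical handler, not TorchInductor's real classes), here is how one could run the IR above to collect the memory reads of a single output element:

import sympy

class RecordLoadsHandler:
    """An ops.* implementation that records memory reads instead of generating code."""
    def __init__(self):
        self.reads = []

    def load(self, buffer_name, index_expr):
        self.reads.append((buffer_name, index_expr))
        return f"load({buffer_name}, {index_expr})"

    def add(self, a, b):
        return f"add({a}, {b})"

ops = RecordLoadsHandler()          # analyses swap in different ops.* implementations
size1 = sympy.Symbol("size1")
i0, i1 = sympy.symbols("i0 i1")

# Run the inner_fn body from the example above against the recording handler:
tmp0 = ops.load("x", i1 + i0 * size1)
tmp1 = ops.load("x", 2 * size1 + i0)
result = ops.add(tmp0, tmp1)
print(ops.reads)                    # two symbolic reads from buffer "x"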

Training Performance Results



Full (and possibly more up-to-date, if you are reading this in the future) performance results can be found in the TorchDynamo Performance Dashboard. The dashboard contains NVIDIA A100 training speedups over eager for float32, float16, and automatic mixed precision (AMP). I will highlight just the AMP results here:

Pass rate for AMP on an NVIDIA A100 GPU

+---------------------+------------+-------------+-------------+
|      Compiler       | torchbench | huggingface | timm_models |
+---------------------+------------+-------------+-------------+
|        eager        | 98%, 48/49 | 100%, 43/43 | 100%, 68/68 |
|     ts_nvfuser      | 92%, 45/49 | 95%, 41/43  | 74%, 50/68  |
|      aot_eager      | 86%, 42/49 | 100%, 43/43 | 97%, 66/68  |
|   aot_cudagraphs    | 84%, 41/49 | 84%, 36/43  | 94%, 64/68  |
|     aot_nvfuser     | 61%, 30/49 | 40%, 17/43  | 49%, 33/68  |
| inductor_cudagraphs | 71%, 35/49 | 86%, 37/43  | 72%, 49/68  |
+---------------------+------------+-------------+-------------+

Geometric mean speedup over eager (of passing models) for AMP on an NVIDIA A100 GPU

+---------------------+------------+-------------+-------------+
|      Compiler       | torchbench | huggingface | timm_models |
+---------------------+------------+-------------+-------------+
|        eager        |    1.0x    |    1.01x    |    1.0x     |
|     ts_nvfuser      |   1.04x    |    1.04x    |    1.03x    |
|      aot_eager      |    1.0x    |    1.0x     |    1.0x     |
|   aot_cudagraphs    |   1.17x    |    1.35x    |    1.06x    |
|     aot_nvfuser     |    1.2x    |    1.38x    |    1.32x    |
| inductor_cudagraphs |    1.9x    |    2.17x    |    1.69x    |
+---------------------+------------+-------------+-------------+

For TorchBench models, 35 out of 49 currently run correctly in training and of those working we see a 1.90x geomean speedup.

For Hugging Face models, 37 out of 43 models run correctly, providing a 2.17x geomean speedup.

For TIMM models, 49 out of 68 run correctly, providing a 1.69x geomean speedup.

For baselines to compare against, we have:

  • eager: baseline that runs the captured FX graph using PyTorch eager mode. This measures the overheads of TorchDynamo.
  • ts_nvfuser: nvFuser using its older TorchScript-based backend.
  • aot_eager: baseline that runs AOT Autograd with a PyTorch eager backend, to measure the overheads of AOT Autograd.
  • aot_cudagraphs: an AOT Autograd backend that applies cudagraphs to reduce overheads.
  • aot_nvfuser: nvFuser using its newer AOT Autograd-based backend.
  • inductor_cudagraphs: TorchInductor, as described in this post.

All of these backends use TorchDynamo as a frontend in this experiment, so they get the same graphs to optimize.

Conclusions

You can check out the TorchInductor source code. To use it you will need to install TorchDynamo, Triton, and PyTorch from source (or use nightlies). Once set up, you can try TorchInductor with:

@torchdynamo.optimize("inductor")
def foo(x):
    …

Or if you have an FX graph already you can call torchinductor.compile_fx(graph, example_inputs).
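
For example, a minimal sketch of both paths (assuming the APIs above, and that compile_fx returns a callable taking the same example inputs):

import torch
import torchdynamo
import torchinductor

# Path 1: let TorchDynamo capture graphs and send them to TorchInductor.
@torchdynamo.optimize("inductor")
def f(x, y):
    return torch.relu(x + y)

print(f(torch.randn(32, 32), torch.randn(32, 32)))

# Path 2: compile an FX graph you already have.
def g(x, y):
    return torch.sigmoid(x @ y)

gm = torch.fx.symbolic_trace(g)
example_inputs = [torch.randn(32, 32), torch.randn(32, 32)]
compiled = torchinductor.compile_fx(gm, example_inputs)
print(compiled(*example_inputs))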

TorchInductor is still an early prototype, so you should expect to find bugs and rough edges. Many models still fail, and there are still missing features.

Please submit bug reports to the TorchDynamo GitHub to help us improve, and we always welcome external contributions.


This is super exciting.

A few questions:

Does inductor_cudagraphs mean that it generates cuda graphs with kernels generated by triton?

How much does TorchInductor generate kernels “from scratch” with triton/etc. vs reusing the existing Torch kernels?

Can you talk about how op fusion works? E.g. can your inner_fn in the post be automatically fused with other “Pointwise” ops or even used as a fused activation function?

In inner_fn – where is size1 defined?

When you say that TorchInductor “is able to represent aliasing and mutation” – what does that mean? What I’ve found is that in practice most backends need to go through a purely functional form and then rely on a buffer allocation pass to make optimal decisions about whether things should reuse buffers or be views/etc. (the way the user wrote is not necessarily optimal).

Does inductor_cudagraphs mean that it generates cuda graphs with kernels generated by triton?

Correct. Triton has high CPU overheads, so cudagraphs helps a lot and is needed. There is an upstream fix coming in Triton that allows AOT kernel generation and will make cudagraphs less important.

How much does TorchInductor generate kernels “from scratch” with triton/etc. vs reusing the existing Torch kernels?

TorchInductor generates nearly all of its kernels automatically from scratch based on its IR.

The two exceptions are matmul/conv where it has a template with auto-generated epilogue fusions. In the current numbers these are disabled by config, and we are just using aten. I’m expecting another 10-20% speedup from enabling this.

There is also a small list of kernels we haven’t implemented yet and are using aten fallbacks: TorchInductor missing ops tracker · Issue #93757 · pytorch/pytorch · GitHub
but eventually we want to codegen everything.

Can you talk about how op fusion works? E.g. can your inner_fn in the post be automatically fused with other “Pointwise” ops or even used as a fused activation function?

Yes, it will be automatically fused. Pointwise ops can be fused with: other pointwise ops; reduction ops; and matmul/conv templates. It also supports fusing multiple reductions/broadcasts together.

The key functions here are can_fuse which tests if two nodes can be fused together, and score_fusion which gives a priority that controls the order fusions happen in. Since some fusions can block other fusions, order matters.
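
Conceptually, this is a greedy best-first fusion pass. A simplified sketch (hypothetical node objects with a fuse method, not the real scheduler):

import itertools

def fuse_greedily(nodes, can_fuse, score_fusion):
    """Repeatedly fuse the highest-scoring legal pair until nothing fuses."""
    while True:
        candidates = [(a, b) for a, b in itertools.combinations(nodes, 2)
                      if can_fuse(a, b)]
        if not candidates:
            return nodes
        # Order matters: one fusion can make another illegal, so always
        # take the highest-scoring candidate first.
        a, b = max(candidates, key=lambda pair: score_fusion(*pair))
        fused = a.fuse(b)               # hypothetical: merge the two nodes
        nodes = [n for n in nodes if n is not a and n is not b] + [fused]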

In inner_fn – where is size1 defined?

There is a per-graph database of symbolic size variables defined in terms of the shapes of the inputs. This is handled in sizevars.py and uses sympy. For clarity, it is basically just:

size1 = sympy.Symbol("size1")

The symbol names are all allocated based on the inputs to the graph, so size1 might be input[0].size(2).
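
In pseudo-code, a simplified version of that bookkeeping (a sketch, not the actual sizevars.py) might look like:

import sympy

class SizeVarAllocator:
    """Per-graph table mapping concrete input sizes to symbolic sizes."""
    def __init__(self):
        self.symbols = {}    # concrete size -> sympy.Symbol
        self.sources = {}    # symbol name -> where it came from

    def size_var(self, value, source):
        if value in (0, 1):               # 0 and 1 are specialized, not symbolic
            return sympy.Integer(value)
        if value not in self.symbols:
            name = f"size{len(self.symbols)}"
            self.symbols[value] = sympy.Symbol(name, integer=True, positive=True)
            self.sources[name] = source
        return self.symbols[value]

sizevars = SizeVarAllocator()
size0 = sizevars.size_var(128, "input[0].size(0)")
size1 = sizevars.size_var(4096, "input[0].size(2)")   # "size1 might be input[0].size(2)"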

When you say that TorchInductor “is able to represent aliasing and mutation” – what does that mean?
What I’ve found is that in practice most backends need to go through a purely functional form and then rely on a buffer allocation pass to make optimal decisions about whether things should reuse buffers or be views/etc. (the way the user wrote is not necessarily optimal).

TorchInductor is “mostly functional,” but not purely functional. There isn’t a good way to represent scatter operations (which show up a lot in backwards) functionally while maintaining good performance. It is really easy to turn O(n) stuff into O(n^2) by trying to functionalize a chain of scatters that only mutate a small fraction of the elements of a tensor. There is also stuff like input mutation, where you don’t control the storage being mutated. The IR directly supports mutation and scatter, though we do make use of dispatcher-level functionalization.
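
A tiny illustration of that asymptotic issue (example only): functionalizing a chain of single-element scatters turns O(n) work into O(n^2), because each out-of-place update copies the whole tensor.

import torch

n = 4096
src = torch.arange(n, dtype=torch.float32)

# Mutating form: each scatter touches one element, so total work is O(n).
x = torch.zeros(n)
for i in range(n):
    x.scatter_(0, torch.tensor([i]), src[i:i + 1])

# Purely functional form: the out-of-place scatter returns a fresh tensor,
# so every step copies all n elements, and total work is O(n^2).
y = torch.zeros(n)
for i in range(n):
    y = y.scatter(0, torch.tensor([i]), src[i:i + 1])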


Really like this direction! I am also 100% for implementing TorchInductor in Python, which makes experimentation a lot easier than having it hidden away.

A key question I have is about the path to GPU computing with TorchInductor. The way I read your post, these are the currently planned paths:

  1. PyTorch → TorchDynamo → TorchInductor → Triton → NVIDIA GPU
  2. PyTorch → TorchDynamo → TorchInductor → OpenMP → CPU

Have you also considered potentially exposing OpenMP Device-Offloading to GPUs?


Yes correct, those are the current paths.

I considered OpenMP’s GPU support, but have heard indirectly from multiple people that it performs poorly, so I abandoned the idea before starting. I’d be curious if you or others have had good experiences with it.

I suspect it would be easy to add.

I’d be happy to implement a prototype given the overlap with my current work, and see how it does in the above tests.

I had a brief glance at the TorchInductor repo, but am a little unsure whether the C++/OpenMP path is already functional. Which files would I have to look at for this?

Yes it is functional and used as the CPU backend. Here is an example:

import torch
import torchdynamo
import torchinductor.config
torchinductor.config.debug = True

@torchdynamo.optimize()
def addrelu(a, b):
    return torch.relu(torch.add(a, b))

addrelu(torch.randn(128, 8192), torch.randn(128, 8192))

If you run this it prints out the generated code:

$ python ex.py 
torchinductor.compile_fx: [INFO] Compiling FORWARDS graph

from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torchinductor.codecache import CppCodeCache, TritonCodeCache

aten = torch.ops.aten

import triton
import triton.language as tl

from torchinductor.triton_ops.autotune import pointwise_heuristics
from torchinductor.triton_ops.autotune import reduction_heuristics
from torchinductor.triton_ops.autotune import grid


kernel0 = CppCodeCache.load('''
#include "/tmp/torchinductor_jansel/i7/ci7dxnvwaxl7gpqj7v4mal2m4yuczvf7n52zpzflqmbs2dbau6lt.h"
extern "C" void kernel(const float* __restrict__ in_ptr0,
                       const float* __restrict__ in_ptr1,
                       float* __restrict__ out_ptr0,
                       const long ks0,
                       const long ks1)
{
    #pragma omp parallel
    {
        #pragma omp for
        for(long i0=0; i0<ks0*ks1; ++i0)
        {
            {
                {
                    auto tmp0 = in_ptr0[i0];
                    auto tmp1 = in_ptr1[i0];
                    auto tmp2 = tmp0 + tmp1;
                    auto tmp3 = tmp2 * (tmp2>0);
                    out_ptr0[i0] = tmp3;
                }
            }
        }
    }
}
''').kernel


def call(arg0_1, arg1_1):
    arg0_1_size = arg0_1.size()
    s0 = arg0_1_size[0]
    s1 = arg0_1_size[1]
    buf0 = empty_strided((s0, s1), (s1, 1), device='cpu', dtype=torch.float32)
    kernel0(c_void_p(arg0_1.data_ptr()), c_void_p(arg1_1.data_ptr()), c_void_p(buf0.data_ptr()), c_long(s0), c_long(s1))
    return (buf0, )


if __name__ == "__main__":
    from torchdynamo.testing import rand_strided
    from torchinductor.utils import print_performance
    arg0_1 = rand_strided((128, 8192), (8192, 1), device='cpu', dtype=torch.float32)
    arg1_1 = rand_strided((128, 8192), (8192, 1), device='cpu', dtype=torch.float32)
    print_performance(lambda: call(arg0_1, arg1_1))

torchinductor.graph: [INFO] Output code: /tmp/torchinductor_jansel/pq/cpqtzrpckt3fn5hxa3b3sjai77oa3i562eegdy6icepd6maz2gah.py

Most of the C++ backend code is in: torchdynamo/cpp.py at main · pytorch/torchdynamo · GitHub

You can join #torchdynamo on the PyTorch slack if you want to ask questions more real time. Feel free to ping me with your email if you need an invite.


Exciting work. One question:
Why did you choose Triton as the GPU code-gen instead of TVM? What were your considerations?

@gong_chen if you (or someone else) wants to try building a TVM backend for TorchInductor, I’d be curious to see the results. I think it would be fairly straightforward, but would require integrating at lower level than Relay IR – so it would require knowledge of TVM internals.

Both nvFuser and NNC have Halide/TVM inspired IRs. Additionally, there is a TVM backend for TorchDynamo using the legacy TorchScript export bindings.

On NVIDIA GPUs, we have observed better performance results from Triton than TVM on most models. Though I don’t think it is apples to apples because a lot is lost in the existing export paths to TVM, and TVM performance varies greatly depending on autotuning. I think the strength of TVM is in its many non-GPU execution targets, while Triton is GPU-only.


@jansel Thank you for your reply. Yes, TorchInductor would require integrating at a lower level than Relay IR. What I mean is using TVM script (TensorIR). I think TVM script is at the same level as Triton: both create a DSL for tensor programming on the Python AST. I am not very sure why Triton can get better performance than TVM on GPU. Might it be that it is easier to incorporate expert experience for manual tiling? BTW, I am curious about why you use “Inductor” to name your work.

For more on Triton I’d suggest the Triton Paper, which might give more insight into why it is faster than TVM. Triton seems to need a lot less tuning.

The name is to continue the theme of TorchDynamo, which is a reference to DynamoRIO.


Super interesting initiative here! I would like to know more about adding a TVM backend for TorchInductor. @jansel Could you send me a link to the PyTorch Slack? My email is yuanjing@octoml.ai.

Hi!
Thanks for the exciting new addition to the PyTorch ecosystem 🙂
I have a few questions regarding the scope of the project:

  • would Inductor work well with for loops, given the currently poor JIT performance? (it would seem so based on the previous answers)
  • if not, would it take user-defined CUDA kernels into account for optimization?
  • could it infer sparsity in the results and automatically discard useless computations in the graph? (e.g. some part of the graph only retains the diagonal of a tensor, or a tensor is symmetric by construction)
  • there are some hand-defined Triton kernels in the previous answers (e.g. Conv2D). Does this mean there would be no further fusion on these kernels? (use case: 3 convolutions with the same input, and the same weights transformed 3 times: absolute value, clipped to [0, +∞), and clipped to (−∞, 0])

would Inductor work well with for loops, given the currently poor JIT performance? (it would seem so based on the previous answers)
if not, would it take user-defined CUDA kernels into account for optimization?

Do you mean loops in Python code, like defining a custom kernel? For that I’d recommend checking out Triton. You could think of Triton as a better CUDA.

There is no reason custom ops shouldn’t work, though they might need some minor glue code.
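
For a feel of the language, here is roughly the standard vector-add example from the Triton tutorials:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                       # which block am I?
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                       # guard the tail
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out

a = torch.randn(10000, device="cuda")
b = torch.randn(10000, device="cuda")
torch.testing.assert_close(add(a, b), a + b)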

could it infer sparsity in the results and automatically discard useless computations in the graph? (e.g. some part of the graph only retains the diagonal of a tensor, or a tensor is symmetric by construction)

Yes this would be possible, but future work and not currently planned.

there are some hand-defined Triton kernels in the previous answers (e.g. Conv2D). Does this mean there would be no further fusion on these kernels? (use case: 3 convolutions with the same input, and the same weights transformed 3 times: absolute value, clipped to [0, +∞), and clipped to (−∞, 0])

Currently, we only support epilogue fusion into Triton templates.


@jansel Hi, it looks like the performance of Inductor benefits a lot from CUDA Graphs.
I am curious how Inductor handles dynamic shapes?
CUDA Graphs cannot deal with ops that have dynamic input/output shapes, which are common in the framework.

For dynamic shapes it doesn’t use cudagraphs.

Since writing this post, there is actually an entirely new Triton runtime which reduces the reliance on cudagraphs a lot. The plan is to lower CPU overheads to the point where cudagraphs isn’t needed for competitive performance.

Is there any post or link we can read about this feature that lowers the CPU overheads?

The changes that already landed:

The next lowest hanging fruit is:

(which I believe is in progress)

How does TorchInductor handle aliasing and mutation? Does it always perform functionalization, or are they exposed to the backend compilers? The description above mentions the concepts of TensorBox and StorageBox; are there more details on these?

TorchInductor uses functionalization, but there are some things (like mutating inputs and scatter operations) that are not functionalized.

TensorBox maps to torch.Tensor. StorageBox maps to torch.Storage. The core representation is a pointer plus strides, same as eager PyTorch. Basically the compiler abstractions match the eager mode abstractions one-to-one.
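
As a loose sketch of that idea (hypothetical classes, not the actual TorchInductor definitions): a StorageBox stands in for the allocation, a TensorBox layers symbolic size/stride/offset on top of it, and a view is just a new TensorBox over the same StorageBox.

from dataclasses import dataclass
from typing import List
import sympy

@dataclass
class StorageBox:                 # plays the role of torch.Storage
    buffer_name: str

@dataclass
class TensorBox:                  # plays the role of torch.Tensor
    storage: StorageBox
    size: List[sympy.Expr]
    stride: List[sympy.Expr]
    offset: sympy.Expr

def permute(t: TensorBox, dims):
    # A view: same storage, reordered size/stride, no data movement.
    return TensorBox(t.storage, [t.size[d] for d in dims],
                     [t.stride[d] for d in dims], t.offset)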