The PyTorch team has been building TorchDynamo, which helps to solve the graph capture problem of PyTorch with dynamic Python bytecode transformation. To actually make PyTorch faster, TorchDynamo must be paired with a compiler backend that converts the captured graphs into fast machine code. We have integrated numerous backends already, and built a lightweight autotuner to select the best backend for each subgraph.
Unfortunately, a lot is lost in translation when exporting to backends that differ greatly from PyTorch. Many involve multiple conversion steps, have a fundamentally different execution model than PyTorch, and have limited support for many PyTorch operators and features. Another key challenge is that, with a few exceptions, most backends are inference-only and have made design decisions that make training support intractable. This export-based execution model will be important for many specific applications, but PyTorch needs a native compiler with abstractions that closely mirror those of PyTorch.
In earlier updates, you have also seen news about our PyTorch-native compilers nvFuser and NNC.
In this post, we will introduce TorchInductor. TorchInductor is a new compiler for PyTorch, which is able to represent all of PyTorch and is built in a general way such that it will be able to support training and multiple backend targets. TorchInductor represents aliasing and mutation through the concepts of TensorBox and StorageBox, which map one-to-one with torch.Tensor and torch.Storage. It handles views by using a symbolically strided tensor that maps directly onto the native torch.Tensor stride representation, which makes views easy to support. Other parts of PyTorch are handled similarly, by mirroring the data model of PyTorch in the backend. The design philosophy is a thin, easily hackable way of symbolically mapping PyTorch to lower-level backends, enabling rapid experimentation, autotuning between different backends, and higher-level optimizations such as memory planning.
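As a rough conceptual sketch of this boxing idea (hypothetical names, not TorchInductor's real classes), each tensor-level box shares a storage-level box, so mutations and aliases made through one view remain visible through every other view:

from dataclasses import dataclass
from typing import List
import sympy

@dataclass
class ToyStorageBox:
    # plays the role of torch.Storage: one flat buffer shared by all aliases
    data: list

@dataclass
class ToyTensorBox:
    # plays the role of torch.Tensor: symbolic sizes/strides over shared storage
    storage: ToyStorageBox
    size: List[sympy.Expr]
    stride: List[sympy.Expr]

    def permute(self, *dims):
        # a view: same storage, reordered symbolic sizes and strides
        return ToyTensorBox(
            self.storage,
            [self.size[d] for d in dims],
            [self.stride[d] for d in dims],
        )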
TorchInductor Design
TorchInductor is implemented in Python. There are pros and cons to this choice, but we have found that it greatly increases velocity and developer productivity. We have also observed that the PyTorch community is much more likely to contribute to parts of PyTorch written in Python, which makes the system more approachable and hackable for our users.
To force the design of TorchInductor to be general, we are starting off with two lower-level execution targets that represent different points in the design space:
- Triton is a new programming language that provides much higher productivity than CUDA, with the ability to beat the performance of highly optimized libraries like cuDNN using clean and simple code. It is developed by Philippe Tillet at OpenAI, and is seeing enormous adoption and traction across the industry. Triton supports NVIDIA GPUs and is quickly growing in popularity as a replacement for hand-written CUDA kernels (a minimal Triton kernel is sketched after this list).
- C++/OpenMP is a widely adopted specification for writing parallel kernels. OpenMP provides a work sharing parallel execution model, and enables support for CPUs. C++ is also an interesting target in that it is a highly portable language and could enable export to more exotic edge devices and hardware architectures.
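To make the Triton point concrete, here is a minimal vector-add kernel in the style of the Triton tutorials. This is a generic illustration of the language, not code generated by TorchInductor:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                            # each program handles one block
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                            # guard the ragged last block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = out.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out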
The approach to building TorchInductor is a breadth-first one. We have spent most of our time making sure the core infrastructure is able to support the vast majority of PyTorch, including: aliasing/mutation/views, scatter (indirect writes), gather (indirect reads), pooling/windows/reductions, masked/conditional execution (padding, etc.), template epilogue fusions, tiling, and horizontal/vertical fusions. So far we have not spent much time optimizing any one pattern, but have focused on general optimizations with widespread benefits.
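As a schematic of what vertical fusion buys (plain Python for illustration, not generated code): an unfused chain of pointwise ops makes one pass over memory per op and materializes intermediates, while the fused version makes a single pass:

import math

def unfused(xs):
    tmp = [x * 2.0 for x in xs]        # "kernel" 1: writes an intermediate buffer
    return [math.exp(t) for t in tmp]  # "kernel" 2: re-reads that buffer

def fused(xs):
    # single pass over the data, no intermediate buffer
    return [math.exp(x * 2.0) for x in xs]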
TorchInductor supports dynamic shapes and strides using the SymPy symbolic math library. It specializes on zero and one, but for other unique tensor sizes it will assign them to a sympy.Symbol and flow them through the entire program. Memory loads and stores are represented directly as sympy indexing formulas based on the iteration variables and symbolic tensor sizes. TorchInductor will introduce guards when needed that lift any assumptions/requirements to the top of the subgraph and will trigger a recompile if those guards fail.
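As a toy illustration of the indexing side of this (plain SymPy, not TorchInductor's actual guard machinery), the flat offset of an element in a contiguous 2D tensor is just a symbolic formula in the iteration variables and symbolic sizes:

import sympy

# hypothetical symbolic sizes and iteration variables
size0, size1 = sympy.symbols("size0 size1", positive=True, integer=True)
i0, i1 = sympy.symbols("i0 i1", integer=True)

# flat offset of element [i0, i1] in a contiguous (size0, size1) tensor
index = i0 * size1 + i1

# specializing a size to 1 simplifies the formula; other sizes stay symbolic,
# protected by guards that trigger recompilation if the assumptions break
print(index.subs(size1, 1))    # prints: i0 + i1
print(index.subs(size1, 128))  # prints: 128*i0 + i1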
Another unique aspect of TorchInductor is that it uses a define-by-run loop-level intermediate representation (IR). Many parts of the IR are Python callables that take SymPy expressions as inputs. We analyze this IR and do codegen by changing the implementation of ops.* and running the IR. As an example, the IR for x.permute(1, 0) + x[2, :] might be something like:
def inner_fn(index: List[sympy.Expr]):
    i1, i0 = index
    tmp0 = ops.load("x", i1 + i0*size1)
    tmp1 = ops.load("x", 2*size1 + i0)
    return ops.add(tmp0, tmp1)

torchinductor.ir.Pointwise(
    device=torch.device("cuda"),
    dtype=torch.float32,
    inner_fn=inner_fn,
    ranges=[size0, size1],
)
Where inner_fn defines how to compute a single element of the output buffer.
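The payoff of the define-by-run style is that analyses and code generators are just different implementations of ops.*. As a rough, hypothetical sketch (not the real TorchInductor codegen), swapping in an ops object that records strings turns "running the IR" into emitting source text:

import sympy

class StringOps:
    # stand-in ops namespace: instead of computing values, build expression strings
    def load(self, name, index):
        return f"{name}[{sympy.ccode(index)}]"

    def add(self, a, b):
        return f"({a} + {b})"

ops = StringOps()
size1 = sympy.Symbol("size1")

def inner_fn(index):
    i1, i0 = index
    tmp0 = ops.load("x", i1 + i0 * size1)
    tmp1 = ops.load("x", 2 * size1 + i0)
    return ops.add(tmp0, tmp1)

i0, i1 = sympy.symbols("i0 i1")
print(inner_fn([i1, i0]))
# prints something like: (x[i0*size1 + i1] + x[i0 + 2*size1])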
Training Performance Results
Full (and possibly more up-to-date, if you are reading this in the future) performance results can be found in the TorchDynamo Performance Dashboard. The dashboard contains NVIDIA A100 training speedups over eager for float32, float16, and automatic mixed precision (AMP). I will highlight just the AMP results here:
Pass rate for AMP on an NVIDIA A100 GPU
+---------------------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+---------------------+------------+-------------+-------------+
| eager | 98%, 48/49 | 100%, 43/43 | 100%, 68/68 |
| ts_nvfuser | 92%, 45/49 | 95%, 41/43 | 74%, 50/68 |
| aot_eager | 86%, 42/49 | 100%, 43/43 | 97%, 66/68 |
| aot_cudagraphs | 84%, 41/49 | 84%, 36/43 | 94%, 64/68 |
| aot_nvfuser | 61%, 30/49 | 40%, 17/43 | 49%, 33/68 |
| inductor_cudagraphs | 71%, 35/49 | 86%, 37/43 | 72%, 49/68 |
+---------------------+------------+-------------+-------------+
Geometric mean speedup over eager (of passing models) for AMP on an NVIDIA A100 GPU
+---------------------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+---------------------+------------+-------------+-------------+
| eager | 1.0x | 1.01x | 1.0x |
| ts_nvfuser | 1.04x | 1.04x | 1.03x |
| aot_eager | 1.0x | 1.0x | 1.0x |
| aot_cudagraphs | 1.17x | 1.35x | 1.06x |
| aot_nvfuser | 1.2x | 1.38x | 1.32x |
| inductor_cudagraphs | 1.9x | 2.17x | 1.69x |
+---------------------+------------+-------------+-------------+
For TorchBench models, 35 out of 49 currently run correctly in training and of those working we see a 1.90x geomean speedup.
For Hugging Face models, 37 out of 43 models run correctly, providing a 2.17x geomean speedup.
For TIMM models, 49 out of 68 run correctly, providing a 1.69x geomean speedup.
For baselines to compare against, we have:
- eager: baseline that runs the captured FX graph using PyTorch eager mode. This measures the overheads of TorchDynamo.
- ts_nvfuser: nvFuser using its older TorchScript-based backend
- aot_eager: baseline that runs AOT Autograd using a PyTorch eager backend, to measure overheads of AOT Autograd.
- aot_cudagraphs: An AOT Autograd backend that applies cudagraphs to reduce overheads
- aot_nvfuser: nvFuser using its newer AOT Autograd-based backend
- inductor_cudagraphs: TorchInductor, as described in this post.
All of these backends use TorchDynamo as a frontend in this experiment, so they get the same graphs to optimize.
Conclusions
You can check out the TorchInductor source code. To use it, you will need to install TorchDynamo, Triton, and PyTorch from source (or use nightlies). Once set up, you can try TorchInductor with:
@torchdynamo.optimize("inductor")
def foo(x):
    ...
Or, if you already have an FX graph, you can call torchinductor.compile_fx(graph, example_inputs).
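For example, a minimal end-to-end run might look like this (assuming a CUDA machine with Triton installed; foo here is just a stand-in function):

import torch
import torchdynamo

@torchdynamo.optimize("inductor")
def foo(a, b):
    return torch.relu(a @ b) + b

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
out = foo(a, b)  # the first call captures the graph and compiles it with TorchInductor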
TorchInductor is still an early prototype, so you should expect to find bugs and rough edges. Many models still fail, and there are still missing features.
Please submit bug reports to the TorchDynamo GitHub repository to help us improve; we always welcome external contributions.