TL;DR: we built a prototype that allows defining pointwise operators in Python, backed by an NNC/FX-based JIT. It supports dynamic shapes and demonstrates up to a 3.5x speedup over `torch.add` on CPU and 1.5x on GPU. It can also be used as an explicit pointwise fusion API, where it beats the performance of the existing TorchScript+NNC fuser by having lower overhead.
With torch::deploy launching and the shift towards using PyTorch eager in more production settings, the PyTorch Compiler team has been thinking about how we can use compiler techniques to improve eager mode where we don’t have access to a whole program graph.
One of the ideas we have been thinking about is a new way to define operators in Python, and have those operators backed by a JIT compiler. This approach has many benefits:
- Operator definitions in Python are more hackable and easier to maintain
- Compilers can make single operators faster (see the performance data below)
- This same interface can also be used as an explicit fusion API
- Smaller binary sizes (useful for mobile)
- Easier to add new architectures
- Operators in a compiler IR make O(num-operators) changes to PyTorch easier
  - New dtypes, vector instructions (AVX512), sparse, vmap, distributed, etc.
It also has a few downsides:
- Initial warm up time could be an issue.
  - We think this is solvable with pre-populated on-disk caches, not over-specializing, background-thread compilation, and other techniques.
- Some environments can’t support dynamic compilation (e.g. mobile)
  - We will need to build an AOT-compiled version and have a way to give a quiescence guarantee to users when required.
To help drive discussion, gather data, and make this idea more concrete, we built a working prototype. The prototype allows you to define a new pointwise operator in Python:
```python
@torch.jit.te.pointwise_operator
def add(a, b):
    return a + b
```
This operator will be JIT compiled and is nearing feature parity with existing TensorIterator-backed pointwise operators. It currently supports: CPU/GPU, broadcasting, type promotion, shape checking, strides, out variants, some aliasing, and some backwards/autograd. Where features are missing, it still performs all the relevant checks to properly simulate the cost of implementing those features.
You can use this same interface to define custom fused operators, for example:
```python
@torch.jit.te.pointwise_operator
def fused_addnorm(a, b, m, d):
    return (a + b - m) / d
```
We imagine extending this interface to handle composite operators, reductions, and other more complex types of ops.
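To make concrete what fusion buys here, the sketch below contrasts the semantics of an unfused eager-style computation with the fused one. This is plain Python over lists (not the prototype's generated kernel): the fused form does one pass with no intermediate buffers, while the unfused form materializes two temporaries.

```python
# Semantic sketch: fused vs. unfused addnorm over plain Python lists.
# The generated kernel behaves like the fused version: one loop, no
# intermediate tensors.

def addnorm_unfused(a, b, m, d):
    t1 = [x + y for x, y in zip(a, b)]     # temporary 1: a + b
    t2 = [x - y for x, y in zip(t1, m)]    # temporary 2: (a + b) - m
    return [x / y for x, y in zip(t2, d)]  # three loops, two temporaries

def addnorm_fused(a, b, m, d):
    # One loop, no temporaries.
    return [(x + y - z) / w for x, y, z, w in zip(a, b, m, d)]

assert addnorm_unfused([1, 2], [3, 4], [2, 2], [2, 2]) == \
       addnorm_fused([1, 2], [3, 4], [2, 2], [2, 2]) == [1.0, 2.0]
```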
The high level design is:
- Collect all the Tensor metadata inputs needed to codegen dispatcher, autograd, and TensorIterator functionality into a `SpecializationKey` object for each tensor. This contains a lot more data than the current dispatch key, but does not include static shapes. (The prototype computes the key dynamically, but it could be cached on the Tensor.)
- Do a single persistent cache lookup to get a precompiled implementation for a tuple of `SpecializationKey`s (one for each input/output to the op)
- On a miss, call the JIT compiler to codegen a specialized implementation and add it to the cache.
- Jump to the specialized implementation
This allows us to bypass much of the overhead and complexity of the existing call path, while still being able to handle the weird corner cases.
Picking the right data to go into the `SpecializationKey` is key here. If we don't specialize enough, we will be forced to add checks to the fast path. If we specialize too much, recompiling could become an issue. Here is an example of what the prototype's `SpecializationKey` looks like for `add(rand(1, 8), rand(8, 8).transpose(0, 1))`:
```python
[SpecializationKey(
     alias_group=0,
     ndim=2,
     dtype=torch.float32,
     device=device(type='cpu'),
     layout=torch.strided,
     requires_grad=False,
     out=False,
     shape=['one', 'other'],
     stride=['contiguous', 'one']),
 SpecializationKey(
     alias_group=0,
     ndim=2,
     dtype=torch.float32,
     device=device(type='cpu'),
     layout=torch.strided,
     requires_grad=False,
     out=False,
     shape=['other', 'other'],
     stride=['one', 'transposed_contiguous'])]
```
There is one key for each input. The fields are as follows:
- `ndim`/`dtype`/`device`/`layout`/`requires_grad` should be self-explanatory.
- `alias_group` tracks aliasing relationships between inputs:
  - 0 means no aliasing.
  - A positive value means "simple aliasing", where all inputs with the same `alias_group` ID point to the same data/shapes/strides and can be folded into a single input to the kernel to generate better code.
  - A negative value (-ID) means "complex aliasing", where the data overlaps but has different shapes/strides; the grouping logic is the same.
- `out` indicates output tensors (for auto-generated `out=` variants).
- `shape` is an enum value for each dimension:
  - "one" means the size of the dimension is 1. We need to know this to compute broadcasting at compile time.
  - "other" is any value not equal to 1.
- `stride` is an enum value for each dimension:
  - "zero" means the stride is exactly 0 (broadcasting).
  - "one" means the stride is exactly 1 (needed for good vectorization).
  - "contiguous" means stride[i] == stride[i+1]*size[i+1] (to compress iteration loops symbolically).
  - "transposed_contiguous" means stride[i] == stride[i-1]*size[i-1] (to compress iteration loops symbolically).
  - "as_arg" means the stride is something else and must be passed into the generated kernel.
For speed, this is implemented with packed bit-vectors. This key may need some tweaking, but we think it strikes a good balance as a starting point for discussion.
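The source doesn't spell out the packing scheme, but the idea can be sketched as follows: give each enum a fixed bit width and fold the per-dimension values into a single integer, so keys hash and compare as machine words. (A real key would also pack `dtype`, `device`, `alias_group`, and the other fields; this hypothetical version packs only the shape/stride enums.)

```python
# Hypothetical bit-packing of the per-dimension enums: 1 bit per
# shape dim (2 categories), 3 bits per stride dim (5 categories).

SHAPE_BITS = {"one": 0, "other": 1}
STRIDE_BITS = {"zero": 0, "one": 1, "contiguous": 2,
               "transposed_contiguous": 3, "as_arg": 4}

def pack_key(shape_key, stride_key):
    bits = 0
    for s in shape_key:
        bits = (bits << 1) | SHAPE_BITS[s]   # 1 bit per dimension
    for s in stride_key:
        bits = (bits << 3) | STRIDE_BITS[s]  # 3 bits per dimension
    return bits

# The two inputs from the earlier example pack to distinct integers:
a = pack_key(["one", "other"], ["contiguous", "one"])
b = pack_key(["other", "other"], ["one", "transposed_contiguous"])
assert a != b
assert a == pack_key(["one", "other"], ["contiguous", "one"])  # deterministic
```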
The chart below shows speedups comparing the performance of our prototype `add()` to the existing `torch.add()` on a wide variety of input types. We show both CPU (1-thread, Coffee Lake) and GPU (GTX 1070) results for sizes 1x1, 512x512, and 8192x8192. The 1x1 size is meant to measure overheads, while the larger sizes show generated code quality. I ran each version hundreds of times (thousands for smaller sizes) and report the median speedup. You can find the definition of each experiment here.
With only a few exceptions, our prototype is faster than or the same as `torch.add`. In most cases the speedup is a few percent, but in some cases it reaches 3.5x. On CPU, we see huge speedups on type promotion test cases. On GPU, we see big speedups (up to 1.5x) on broadcasting test cases.
The last 3 bars show a small example of an explicit pointwise fusion. For the non-out-variant ones, we have an existing fuser in TorchScript that can also fuse this example:
```python
# This prototype:
@torch.jit.te.pointwise_operator
def fused_addnorm(a, b, m, d):
    return (a + b - m) / d

# Same algorithm with the existing TorchScript-based fuser:
torch._C._jit_override_can_fuse_on_cpu(True)

@torch.jit.script
def fused_addnorm_ts(a, b, m, d):
    return (a + b - m) / d
```
Here is a performance comparison showing speedups over eager (unfused) as a baseline and speedups over TorchScript fusion as a second baseline.
CPU speedups:

| | 1x1 | 512x512 | 8192x8192 |
|---|---|---|---|
| forward (vs unfused) | 3.41x | 1.99x | 2.39x |
| forward (vs TS-fused) | 3.16x | 1.07x | 1.00x |
| backward (vs unfused) | 1.14x | 1.57x | 1.67x |
| backward (vs TS-fused) | 1.70x | 1.12x | 1.00x |

GPU speedups:

| | 1x1 | 512x512 | 8192x8192 |
|---|---|---|---|
| forward (vs unfused) | 1.94x | 1.36x | 1.92x |
| forward (vs TS-fused) | 1.48x | 1.14x | 1.00x |
| backward (vs unfused) | 1.18x | 1.18x | 1.36x |
| backward (vs TS-fused) | 1.33x | 1.32x | 1.00x |
We can see that for 8192x8192, the two fusers perform the same (1.00x) for both CPU/GPU and forward/backward, but for smaller sizes this prototype is faster than TorchScript because it has much lower overheads. This prototype also supports dynamic shapes without needing to recompile, while the TorchScript version shape-specializes.
This is still an early prototype and not yet production ready. There are still many challenges to overcome, optimization opportunities to explore, and integration issues to resolve. We are hoping this will start a discussion and help get feedback on this new direction.