Python Operator Authoring w/ NNC

with @bertmaher and @ZolotukhinM

TL;DR: We built a prototype that allows defining pointwise operators in Python, backed by an NNC/FX-based JIT. It supports dynamic shapes and demonstrates speedups of up to 3.5x over torch.add on CPU and up to 1.5x on GPU. It can also be used as an explicit pointwise fusion API, where it outperforms the existing TorchScript+NNC fuser because of its lower overhead.

Motivation

With torch::deploy launching and the shift towards using PyTorch eager in more production settings, the PyTorch Compiler team has been thinking about how we can use compiler techniques to improve eager mode where we don’t have access to a whole program graph.

One of the ideas we have been thinking about is a new way to define operators in Python, and have those operators backed by a JIT compiler. This approach has many benefits:

  • Operator definitions in Python are more hackable and easier to maintain
  • Compilers can make single operators faster (data demonstrating a faster torch.add below)
  • The same interface can also be used as an explicit fusion API
  • Smaller binary sizes (useful for mobile)
  • Easier to add new architectures
  • Operators in a compiler IR make doing O(num-operators) changes to PyTorch easier
    • New dtypes, vector instructions (AVX512), sparse, vmap, distributed, etc

It also has a few downsides:

  • Initial warm up time could be an issue.
    • We think this is solvable with pre-populated on-disk caches, not over-specializing, background thread compilation, and other techniques.
  • Some environments can’t support dynamic compilation (e.g. mobile)
    • We will need to build an AOT compiled version and have a way to give a quiescence guarantee to users when required.

User Interface

To help drive discussion, gather data, and make this idea more concrete, we built a working prototype. The prototype allows you to define a new pointwise operator in Python:

@torch.jit.te.pointwise_operator
def add(a, b):
    return a + b

This operator will be JIT compiled and is nearing feature parity with the existing TensorIterator-backed pointwise operators. It currently supports CPU/GPU, broadcasting, type promotion, shape checking, strides, out variants, some aliasing, some backward/autograd, etc. Where features are still missing, it performs all the relevant checks anyway, so that it properly simulates the cost of implementing those features.

You can use this same interface to define custom fused operators, for example:

@torch.jit.te.pointwise_operator
def fused_addnorm(a, b, m, d):
    return (a + b - m) / d

We imagine extending this interface to handle composite operators, reductions, and other more complex types of ops.
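The post does not show how the decorator works internally beyond saying the JIT is NNC/FX-based. As a hypothetical illustration that avoids a torch dependency, the sketch below uses a tiny hand-written tracer in place of torch.fx: the decorated function is called once with symbolic `Proxy` values to record its expression, and the resulting "fused" operator evaluates the whole expression per element in a single loop. `toy_pointwise_operator`, `Proxy`, and `evaluate` are invented names, and plain Python lists stand in for tensors.

```python
import inspect
import operator

# Only Proxy-with-Proxy binary ops are supported in this sketch (no constants).
OPS = {"add": operator.add, "sub": operator.sub,
       "mul": operator.mul, "truediv": operator.truediv}


class Proxy:
    """Symbolic value: records arithmetic as a nested-tuple expression instead of computing it."""

    def __init__(self, expr):
        self.expr = expr  # ("arg", name) or (op_name, lhs_expr, rhs_expr)

    def _binary(self, op_name, other):
        return Proxy((op_name, self.expr, other.expr))

    def __add__(self, other):
        return self._binary("add", other)

    def __sub__(self, other):
        return self._binary("sub", other)

    def __mul__(self, other):
        return self._binary("mul", other)

    def __truediv__(self, other):
        return self._binary("truediv", other)


def evaluate(expr, env):
    """Interpret a recorded expression for one element; a real JIT would emit code instead."""
    if expr[0] == "arg":
        return env[expr[1]]
    op_name, lhs, rhs = expr
    return OPS[op_name](evaluate(lhs, env), evaluate(rhs, env))


def toy_pointwise_operator(fn):
    """Hypothetical decorator: trace `fn` once, then run the whole expression per element."""
    names = list(inspect.signature(fn).parameters)
    graph = fn(*(Proxy(("arg", n)) for n in names)).expr  # trace once at decoration time

    def fused(*inputs):
        length = len(inputs[0])
        if any(len(x) != length for x in inputs):
            raise ValueError("this sketch does not broadcast")
        # One loop evaluates the entire expression per element, so no
        # intermediate buffers for (a + b) or (a + b - m) are materialized.
        return [evaluate(graph, dict(zip(names, values))) for values in zip(*inputs)]

    fused.graph = graph
    return fused


@toy_pointwise_operator
def fused_addnorm(a, b, m, d):
    return (a + b - m) / d


print(fused_addnorm([1.0, 2.0], [3.0, 4.0], [2.0, 2.0], [2.0, 2.0]))  # [1.0, 2.0]
```

The point of the sketch is the shape of the interface: the user writes ordinary Python arithmetic, and the decorator turns it into one expression graph that a compiler backend (NNC, in the prototype) could lower into a single fused kernel.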

Implementation and Specialization

The high level design is:

  • Collect all the Tensor metadata inputs needed to codegen dispatcher, autograd, and TensorIterator functionality into a SpecializationKey object for each tensor. This contains a lot more data than the current dispatch key, but does not include static shapes. (The prototype computes the key dynamically, but it could be cached on the Tensor.)
  • Do a single persistent cache lookup to get a precompiled implementation for a tuple of SpecializationKeys (one for each input/output to the op)
    • On a miss, call the JIT compiler to codegen a specialized implementation and add it to the cache.
  • Jump to the specialized implementation

This allows us to bypass much of the overhead and complexity of the existing call path while still handling the weird corner cases.
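A minimal sketch of the compute-key, look-up, compile-on-miss, jump flow described above. Everything here is hypothetical: `FakeTensor` stands in for a real tensor, `Key` is a heavily reduced SpecializationKey (dtype, device, and a per-dimension "one"/"other" shape enum only; the real key also carries strides, aliasing, requires_grad, etc.), and `toy_compile` builds a Python closure where the prototype would invoke the NNC JIT.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: metadata plus flat data."""
    dtype: str
    device: str
    sizes: Tuple[int, ...]
    data: Tuple[float, ...]


@dataclass(frozen=True)
class Key:
    """A heavily reduced, hypothetical SpecializationKey. No static sizes."""
    dtype: str
    device: str
    shape: Tuple[str, ...]  # "one" or "other" per dimension (so ndim is implied)


def specialization_key(t: FakeTensor) -> Key:
    return Key(t.dtype, t.device, tuple("one" if s == 1 else "other" for s in t.sizes))


def toy_compile(fn: Callable, key: Tuple[Key, ...]) -> Callable:
    # Stand-in for the JIT: a real compiler would emit code specialized to `key`.
    # This toy kernel assumes all inputs have identical sizes (no broadcasting).
    def kernel(*tensors: FakeTensor) -> FakeTensor:
        values = tuple(fn(*xs) for xs in zip(*(t.data for t in tensors)))
        first = tensors[0]
        return FakeTensor(first.dtype, first.device, first.sizes, values)
    return kernel


class SpecializingOperator:
    def __init__(self, fn: Callable, compile_fn: Callable = toy_compile):
        self.fn = fn
        self.compile_fn = compile_fn
        self.cache: Dict[Tuple[Key, ...], Callable] = {}
        self.compiles = 0

    def __call__(self, *args: FakeTensor) -> FakeTensor:
        key = tuple(specialization_key(a) for a in args)  # one key per input
        kernel = self.cache.get(key)                      # single cache lookup
        if kernel is None:                                # miss: compile and remember
            kernel = self.compile_fn(self.fn, key)
            self.cache[key] = kernel
            self.compiles += 1
        return kernel(*args)                              # jump to the specialized kernel


def make(sizes, values):
    return FakeTensor("float32", "cpu", tuple(sizes), tuple(values))


add = SpecializingOperator(lambda a, b: a + b)
add(make((2, 3), [1.0] * 6), make((2, 3), [2.0] * 6))
add(make((4, 5), [1.0] * 20), make((4, 5), [2.0] * 20))  # new sizes, same key: no recompile
print(add.compiles)  # 1
add(make((1, 3), [1.0] * 3), make((1, 3), [2.0] * 3))    # a size-1 dim changes the key
print(add.compiles)  # 2
```

Because static sizes are left out of the key, the (2, 3) and (4, 5) calls share one compiled kernel, which is the dynamic-shape behavior the post describes; only a change that affects code generation (here, a size-1 dimension) triggers a recompile.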

Picking the right data to go into the SpecializationKey is crucial. If we don’t specialize enough, we will be forced to add checks to the fast path. If we specialize too much, recompilation could become an issue. Here is an example of what the prototype SpecializationKey looks like for add(rand(1, 8), rand(8, 8).transpose(0, 1)):

[SpecializationKey(
    alias_group=0,
    ndim=2,
    dtype=torch.float32,
    device=device(type='cpu'),
    layout=torch.strided,
    requires_grad=False,
    out=False,
    shape=['one', 'other'],
    stride=['contiguous', 'one']),
 SpecializationKey(
    alias_group=0,
    ndim=2,
    dtype=torch.float32,
    device=device(type='cpu'),
    layout=torch.strided,
    requires_grad=False,
    out=False,
    shape=['other', 'other'],
    stride=['one', 'transposed_contiguous'])]

There is one key for each input. The fields are as follows:

  • ndim/dtype/device/layout/requires_grad should be self-explanatory.
  • alias_group tracks aliasing relationships between inputs:
    • 0 means no aliasing.
    • A positive value means “simple aliasing”, where all inputs with the same alias_group ID point to the same data/shapes/strides and can be folded into a single input to the kernel to generate better code.
    • A negative value (-ID) means “complex aliasing” (where the data overlaps, but has different shapes/strides), the grouping logic is the same.
  • out indicates output tensors (for auto-generated out= variants)
  • shape is an enum value for each dimension:
    • “one” means the size of the dimension is 1. We need to know this to compute broadcasting at compile time.
    • “other” is any value not equal to 1.
  • stride is an enum value for each dimension:
    • “zero” means the stride is exactly 0 (broadcasting).
    • “one” means the stride is exactly 1 (needed for good vectorization).
    • “contiguous” means stride[i] == stride[i+1]*size[i+1] (to compress iteration loops symbolically).
    • “transposed_contiguous” means stride[i] == stride[i-1]*size[i-1] (to compress iteration loops symbolically).
    • “as_arg” means the stride is something else and must be passed into the generated kernel.

For speed, this is implemented with packed bit-vectors. This key may need some tweaking, but we think it strikes a good balance as a starting point for discussion.
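The post does not give the bit layout, so the following is only a hypothetical packing to illustrate the idea: 8 low bits hold ndim, and each dimension then gets 4 bits (1 bit for the shape enum, 3 bits for the five stride values). The per-dimension part of a key collapses into a single integer, so comparing or hashing it is one integer operation.

```python
# Hypothetical layout; the prototype's actual bit assignment is not described.
SHAPE_VALUES = ("other", "one")                       # 1 bit per dim
STRIDE_VALUES = ("zero", "one", "contiguous",
                 "transposed_contiguous", "as_arg")   # 3 bits per dim
NDIM_BITS = 8
DIM_BITS = 4


def pack_key(shape, stride):
    assert len(shape) == len(stride) < (1 << NDIM_BITS)
    packed = len(shape)  # low bits: ndim
    for i, (sh, st) in enumerate(zip(shape, stride)):
        field = SHAPE_VALUES.index(sh) | (STRIDE_VALUES.index(st) << 1)
        packed |= field << (NDIM_BITS + DIM_BITS * i)
    return packed


def unpack_key(packed):
    ndim = packed & ((1 << NDIM_BITS) - 1)
    shape, stride = [], []
    for i in range(ndim):
        field = (packed >> (NDIM_BITS + DIM_BITS * i)) & ((1 << DIM_BITS) - 1)
        shape.append(SHAPE_VALUES[field & 1])
        stride.append(STRIDE_VALUES[field >> 1])
    return shape, stride


key = pack_key(["one", "other"], ["contiguous", "one"])
print(key, unpack_key(key))  # 9474 (['one', 'other'], ['contiguous', 'one'])
```

A C++ implementation would likely use a fixed-width integer (which caps ndim); Python's unbounded ints hide that detail here.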

Performance Results

The chart below shows the speedups of our prototype add() over the existing torch.add() on a wide variety of input types. We show both CPU (1 thread, Coffee Lake) and GPU (GTX 1070) results for sizes 1x1, 512x512, and 8192x8192. The 1x1 size measures overheads, while the larger sizes show generated code quality. I ran each version hundreds of times (thousands for the smaller sizes) and report the median speedup. You can find the definition of each experiment here.

With only a few exceptions, our prototype is faster than or as fast as torch.add. In most cases the speedup is a few percent, but in some cases it reaches 3.5x. On CPU, we see huge speedups on the type promotion test cases. On GPU, we see big speedups (up to 1.5x) on the broadcasting test cases.
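The exact benchmark harness is not shown. A minimal sketch of the "median speedup" measurement, assuming each trial times the baseline and the candidate back to back (interleaving reduces the effect of drift) and the reported figure is the median of per-trial ratios. `median_speedup` is a hypothetical helper. For GPU measurements the device would have to be synchronized (e.g. with torch.cuda.synchronize()) before reading the timer, since kernel launches are asynchronous; that is omitted here.

```python
import statistics
import timeit
from typing import Callable


def median_speedup(baseline: Callable[[], object], candidate: Callable[[], object],
                   repeat: int = 200, number: int = 100) -> float:
    """Median over `repeat` trials of (baseline time / candidate time)."""
    speedups = []
    for _ in range(repeat):
        base = timeit.timeit(baseline, number=number)   # `number` calls per trial
        cand = timeit.timeit(candidate, number=number)
        speedups.append(base / cand)
    return statistics.median(speedups)


data = list(range(1000))
ratio = median_speedup(lambda: [x + 1 for x in data],
                       lambda: list(map(lambda x: x + 1, data)),
                       repeat=50, number=20)
print(f"median speedup: {ratio:.2f}x")
```

Taking the median rather than the mean keeps a few outlier trials (page faults, scheduler noise) from skewing small-size results, where the timings are dominated by overhead.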

Pointwise Fusion versus TorchScript

The last three bars show a small example of explicit pointwise fusion. For the non-out-variant cases, the existing TorchScript fuser can also fuse this example:

import torch

# This prototype:
@torch.jit.te.pointwise_operator
def fused_addnorm(a, b, m, d):
    return (a + b - m) / d

# Same algorithm with the existing TorchScript-based fuser:
torch._C._jit_override_can_fuse_on_cpu(True)  # enable the TorchScript fuser on CPU
@torch.jit.script
def fused_addnorm_ts(a, b, m, d):
    return (a + b - m) / d

Here is a performance comparison showing speedups over eager (unfused) as a baseline and speedups over TorchScript fusion as a second baseline.

CPU speedups of this prototype over eager (unfused) and over TorchScript (TS-fused)
                       1x1    512x512 8192x8192
forward (vs unfused)   3.41x  1.99x   2.39x
forward (vs TS-fused)  3.16x  1.07x   1.00x
backward (vs unfused)  1.14x  1.57x   1.67x
backward (vs TS-fused) 1.70x  1.12x   1.00x

GPU speedups of this prototype over eager (unfused) and over TorchScript (TS-fused)
                       1x1    512x512 8192x8192
forward (vs unfused)   1.94x  1.36x   1.92x
forward (vs TS-fused)  1.48x  1.14x   1.00x
backward (vs unfused)  1.18x  1.18x   1.36x
backward (vs TS-fused) 1.33x  1.32x   1.00x

We can see that for 8192x8192, the two fusers perform the same (1.00x) for both CPU/GPU and forward/backward, but for smaller sizes this prototype is faster than TorchScript because it has much lower overheads. This prototype also supports dynamic shapes without needing to recompile, while the TorchScript version shape specializes.
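The two baselines in each table also imply how the TorchScript fuser itself compares with eager: if s_eager = T_eager / T_proto and s_ts = T_ts / T_proto, then s_eager / s_ts = T_eager / T_ts. The snippet below does that arithmetic on the numbers copied from the tables; the results are only approximate, since each table entry is itself a median.

```python
SIZES = ("1x1", "512x512", "8192x8192")

# Prototype speedups copied from the tables above:
# (vs unfused eager, vs TS-fused) for each device and pass.
SPEEDUPS = {
    ("CPU", "forward"):  ((3.41, 1.99, 2.39), (3.16, 1.07, 1.00)),
    ("CPU", "backward"): ((1.14, 1.57, 1.67), (1.70, 1.12, 1.00)),
    ("GPU", "forward"):  ((1.94, 1.36, 1.92), (1.48, 1.14, 1.00)),
    ("GPU", "backward"): ((1.18, 1.18, 1.36), (1.33, 1.32, 1.00)),
}


def ts_over_eager(vs_unfused, vs_ts):
    # (T_eager / T_proto) / (T_ts / T_proto) = T_eager / T_ts
    return tuple(u / t for u, t in zip(vs_unfused, vs_ts))


for (device, pass_name), (vs_unfused, vs_ts) in SPEEDUPS.items():
    implied = ts_over_eager(vs_unfused, vs_ts)
    cells = "  ".join(f"{s}: {r:.2f}x" for s, r in zip(SIZES, implied))
    print(f"{device} {pass_name:<8} TS-fused vs eager  {cells}")
```

By this arithmetic, the TS-fused backward at 1x1 comes out slower than unfused eager on both CPU (about 0.67x) and GPU (about 0.89x), which is consistent with the overhead explanation above.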

Next Steps

This is still an early prototype and not yet production ready. Many challenges, optimization opportunities, and integration issues remain. We hope this will start a discussion and help us gather feedback on this new direction.


Thanks for sharing.

What’s the status of this project?

What’s a summary of the additional overhead involved in calling the equivalent TorchScript code vs this approach?

Are those same TorchScript overheads responsible for the slowness of torch.add, or is that due to something else?

Some of the work started by this project has been carried forward in the functorch repo with the early AOT Autograd / compilation work there.

Overheads come from lots of different places and add up through many layers of the stack.

Is it basically the same overheads in the “Pointwise Fusion versus TorchScript” and torch.add cases?

The project has evolved to be more focused on training. The low-overhead caching part is similar though.

Thanks for sharing.
Compared with the TS-fused version, which overheads does this prototype reduce?