TL;DR

We built the “jiterator,” a CUDA TensorIterator interface that allows the elementwise CUDA kernels to be just-in-time compiled when used. This lets us greatly reduce the build time of these operators and means they have no impact on CUDA context size unless used. Today in PyTorch the `digamma`, `trigamma`, `lgamma`, `gcd`, `lcm`, `i0`, `i0e`, `i1`, `i1e`, `ndtri`, `erfcx` and `zeta` operations are “jiterated.” These operators behave just like they did when precompiled and have comparable performance after they’re just-in-time compiled. Jiterating these operations alone has reduced PyTorch’s initial CUDA context size by 3%, and a few additional targeted optimizations identified while working on the jiterator have produced a total reduction of 7% since December.

Just-in-time compilation is not without its own challenges and drawbacks, however, and this note discusses the design of the jiterator, its impact, and planned future work. We’re excited to share the first “just-in-time” compilation system for PyTorch’s eager mode and how it’s keeping PyTorch maintainable even as we continue to add operators.

### GOALS

The jiterator is intended to do three things:

1. Faithfully reproduce PyTorch’s current eager behavior and performance while just-in-time compiling elementwise CUDA operators
2. Reduce PyTorch’s build time and keep the build time reasonable as we add operators
3. Reduce PyTorch’s initial CUDA context size and keep it reasonable as we add operators

To achieve these goals we reuse TensorIterator and carefully mimic the behavior of Loops.cuh and CUDALoops.cuh to just-in-time compile the same types of kernels that are precompiled today. We use nvrtc for just-in-time compilation of CUDA C strings because it’s a familiar technology used by existing code generators. This approach provides an easy-to-use TensorIterator-compatible interface that creates operations which behave and (after just-in-time compilation) perform just like their precompiled counterparts.

### THE JITERATOR

Every op we add to PyTorch comes with costs, and two of those costs are increased build time and a larger CUDA context size if the operator has its own CUDA kernel. There are a variety of methods we can use to mitigate these impacts, but the ultimate hammer is just-in-time compiling — AKA “jitting” — operations when they’re used. Jitting ops is not without challenges and drawbacks of its own, however:

- Just-in-time compilation techniques, like passing a CUDA C string to nvrtc, have different limitations than traditional precompilation of CUDA C
- Practical just-in-time compilation techniques (using nvrtc or writing ptx and compiling with ptxas) have significant compilation times that make the first call to just-in-time compiled operators slow
- Matching precompiled performance requires identifying ahead of time the set of kernels that can be just-in-time compiled

The next sections elaborate on each of the challenges and how the jiterator addresses them (or plans to address them).

### A TENSORITERATOR-COMPATIBLE INTERFACE

PyTorch already has a lot of architecture to facilitate writing elementwise kernels using dispatch macros and TensorIterator. The relevant snippets for the CUDA implementation of `gcd` (greatest common divisor) after a TensorIterator object has been constructed are:

```
// Dispatch: launch a TensorIterator kernel that applies the gcd functor elementwise.
AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cuda", [&]() {
  gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
    return calc_gcd(a, b);
  });
});

// Functor body (defined elsewhere): Euclid's algorithm on the absolute values.
template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) {
  scalar_t a = ::abs(a_in);
  scalar_t b = ::abs(b_in);
  while (a != 0) {
    scalar_t c = a;
    a = b % a;
    b = c;
  }
  return b;
}
```

These snippets describe gcd’s “functor,” the operation called in a TensorIterator kernel that takes care of tensors’ shapes, contiguities, and type promotion (which are encoded in the `iter` object). The TensorIterator interface abstracts a tremendous amount of complexity away for developers, making it straightforward to define the gcd operation as a functor of two scalars.
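To make the division of labor concrete, here is a minimal host-only sketch of the same idea. The `elementwise_binary` helper is a hypothetical stand-in for `gpu_kernel`: it is a serial CPU loop that assumes equal-length contiguous inputs, so none of the stride, broadcasting, or type-promotion handling that TensorIterator provides is modeled. Only the functor mirrors the snippet above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <vector>

// Same Euclid's-algorithm functor as above, written for the host only.
template <typename scalar_t>
scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) {
  scalar_t a = std::abs(a_in);
  scalar_t b = std::abs(b_in);
  while (a != 0) {
    scalar_t c = a;
    a = b % a;
    b = c;
  }
  return b;
}

// Hypothetical stand-in for gpu_kernel: applies a two-argument functor to
// every pair of elements. Assumes same-sized, contiguous inputs.
template <typename scalar_t, typename func_t>
std::vector<scalar_t> elementwise_binary(const std::vector<scalar_t>& a,
                                         const std::vector<scalar_t>& b,
                                         func_t op) {
  assert(a.size() == b.size());
  std::vector<scalar_t> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = op(a[i], b[i]);
  }
  return out;
}
```

The functor only ever sees two scalars; everything about how elements are located and iterated lives in the helper, which is the part TensorIterator takes over in the real kernels.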

Now let’s look at the analogous snippets in the jiterated version of `gcd`:

```
// Dispatch: describe the operation explicitly and pass its source string to the jiterator.
AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cuda", [&]() {
  jitted_gpu_kernel</*name=*/gcd_name,
                    /*return_dtype=*/ scalar_t,
                    /*common_dtype=*/ scalar_t,
                    /*arity=*/ 2>(iter, gcd_string);
});

// Functor source (defined elsewhere), captured as a string for nvrtc.
const auto gcd_string = jiterator_stringify(
  template <typename T>
  T gcd(const T a_in, const T b_in) {
    T a = abs(a_in);
    T b = abs(b_in);
    while (a != T{0}) {
      T c = a;
      a = b % a;
      b = c;
    }
    return b;
  }
); // gcd_string
```

The code is pretty similar. The dispatch macro and functor are effectively the same (although the functor is now passed to a macro). The call to jitted_gpu_kernel() may look a little intimidating, but it just describes some properties of the gcd operation explicitly — like how many arguments the functor takes. To create the complete kernel the jiterator inserts the string representation of the gcd functor into a string representation of a TensorIterator kernel before shipping the entire string to nvrtc.
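The string-insertion step can be sketched in plain C++. Everything named here is hypothetical and for illustration only: `SKETCH_STRINGIFY` stands in for `jiterator_stringify`, `kKernelTemplate` is a toy kernel template with `${functor}` and `${name}` placeholders, and `build_kernel_source` is an invented helper. The real kernel template also has to handle strides, casting, and vectorization, and this sketch stops before the string is handed to nvrtc.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for jiterator_stringify: turns its arguments into a
// string literal via preprocessor stringification.
#define SKETCH_STRINGIFY(...) #__VA_ARGS__

// Toy kernel template (an assumption, not the jiterator's actual template).
static const char* const kKernelTemplate =
    "${functor}\n"
    "extern \"C\" __global__ void ${name}_kernel(\n"
    "    const long* a, const long* b, long* out, int n) {\n"
    "  int i = blockIdx.x * blockDim.x + threadIdx.x;\n"
    "  if (i < n) out[i] = ${name}<long>(a[i], b[i]);\n"
    "}\n";

static const char* const gcd_functor = SKETCH_STRINGIFY(
  template <typename T>
  T gcd(const T a_in, const T b_in) {
    T a = abs(a_in);
    T b = abs(b_in);
    while (a != T{0}) {
      T c = a;
      a = b % a;
      b = c;
    }
    return b;
  }
);

// Replaces every occurrence of `key` in `text` with `value`.
std::string replace_all(std::string text, const std::string& key,
                        const std::string& value) {
  assert(!key.empty());
  std::string::size_type pos = text.find(key);
  while (pos != std::string::npos) {
    text.replace(pos, key.size(), value);
    pos = text.find(key, pos + value.size());
  }
  return text;
}

// Splices the functor source and its name into the kernel template.
std::string build_kernel_source(const std::string& name,
                                const std::string& functor) {
  std::string src = replace_all(kKernelTemplate, "${functor}", functor);
  return replace_all(src, "${name}", name);
}
```

Calling `build_kernel_source("gcd", gcd_functor)` yields one self-contained CUDA C string containing both the functor and a kernel that calls it, which is the kind of input a runtime compiler expects.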

We think using the jiterator is comparable to writing precompiled CUDA kernels, and we’ve successfully ported some of our most complicated elementwise CUDA kernels to it.

### COMPILATION TIME

When gcd was precompiled it took about 58 microseconds to run (on a devfair, with CUDA long tensors with 2^16 elements). The first time the jiterated gcd op is called, however, it takes over 2500 times as long to run, because that first call has to just-in-time compile the kernel. Subsequent calls to gcd that can reuse that kernel take the same time as when gcd was precompiled.

Calls to JAX operations exhibit the same pattern, but frameworks like JAX aren’t intended for eagerly running operations like PyTorch is. Compilation time may be especially painful for users progressively writing and incrementally validating their scripts because every run triggers the same just-in-time compilations. To limit the impact of just-in-time compilation time we’re only jiterating little-used operators ahead of PyTorch 1.11. Better just-in-time compilation strategies are an interesting direction for Future Work (see below).

### MATCHING PRECOMPILED PERFORMANCE (AFTER THE INITIAL COMPILATION)

While the first call to a just-in-time compiled operator is slow, subsequent calls can be as fast as if the operator were precompiled. The trick to this is identifying the set of kernels we could generate in all possible scenarios and precompiling the dispatch to them.

For non-jiterated operations PyTorch compiles a fixed set of kernels when built. Elementwise TensorIterator kernels get six for every dtype they support: one for arbitrary strides, another for arbitrary strides with type promotion, another for contiguous inputs with type casting, and finally three vectorized kernels (with different vectorization lengths) for contiguous inputs without type casting. So if an elementwise operator supports two datatypes (like float and double) it compiles 12 kernels, if it supports four it compiles 24, and so on. For elementwise binary operations that support scalars, three times as many kernels are precompiled (36 instead of 12, 72 instead of 24)! That’s a lot of kernels, and it’s easy to see why especially complicated kernels, like the zeta or erfcx operations, might take a long time to compile traditionally.
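The arithmetic in that paragraph can be written down directly. This is only a worked restatement of the counts given above; `precompiled_kernel_count` and `kKernelsPerDtype` are hypothetical names, not PyTorch functions.

```cpp
#include <cassert>

// Six kernels per dtype: arbitrary strides, arbitrary strides with type
// promotion, contiguous with type casting, and three vectorized variants.
const int kKernelsPerDtype = 1 + 1 + 1 + 3;

// Number of kernels precompiled for one elementwise operator, following the
// counts stated in the text. Binary operators that support scalars get three
// times as many kernels.
int precompiled_kernel_count(int num_dtypes, bool binary_with_scalars) {
  assert(num_dtypes >= 0);
  const int multiplier = binary_with_scalars ? 3 : 1;
  return kKernelsPerDtype * num_dtypes * multiplier;
}
```

For example, `precompiled_kernel_count(2, false)` gives 12 and `precompiled_kernel_count(4, true)` gives 72, matching the figures above.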

The jiterator can compile the same set of kernels, and because this set is known when PyTorch is built we can construct the runtime mapping from all inputs to all possible kernels at build time, too, making dispatch to cached just-in-time compiled kernels as fast as non-jiterated dispatch. Callgrind benchmarking shows that the CPU instruction count is actually slightly smaller when a jit-compiled kernel is launched than when the same operation is launched the regular way (182565 instructions for 10 iterations with the jiterator vs 192966 instructions without). Timing the CPU overhead of launching a kernel is very noisy, but shows results around 5 us in both cases. A more complete performance analysis that measures the performance of the GPU kernel shows there are some input shapes where the jiterator’s kernels are a shade faster than eager kernels, and others where they’re a shade slower, but they’re effectively the same. P474088187 shows achieved bandwidth for the gcd operation for the different sizes and broadcasting patterns with the regular and jitted implementations.
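One way to picture “precompiling the dispatch” is a fixed table with one slot per possible kernel variant, where each slot is filled lazily the first time it is needed. The sketch below is an illustration under assumptions, not the jiterator’s implementation: the `Variant` enum, `select_variant`, and `KernelCache` are hypothetical names, the vectorization lengths 1, 2 and 4 are assumed, the “compiled kernel” is just a string produced by a caller-supplied function standing in for nvrtc, and thread safety is ignored.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <utility>

// The six per-dtype variants described above (vector lengths are assumed).
enum class Variant { Strided = 0, StridedCast, ContiguousCast, Vec1, Vec2, Vec4, Count };

// The input-to-variant mapping is fixed at build time; only compilation is deferred.
Variant select_variant(bool contiguous, bool needs_cast, int vec_size) {
  assert(vec_size == 1 || vec_size == 2 || vec_size == 4);
  if (!contiguous) return needs_cast ? Variant::StridedCast : Variant::Strided;
  if (needs_cast) return Variant::ContiguousCast;
  if (vec_size == 4) return Variant::Vec4;
  return vec_size == 2 ? Variant::Vec2 : Variant::Vec1;
}

struct KernelSlot {
  bool ready = false;
  std::string binary;  // stand-in for a compiled kernel handle
};

class KernelCache {
 public:
  explicit KernelCache(std::function<std::string(Variant)> compile)
      : compile_(std::move(compile)) {}

  // The first request for a variant pays the compile cost; later requests
  // are a single array lookup.
  const std::string& get(Variant v) {
    KernelSlot& slot = slots_[static_cast<std::size_t>(v)];
    if (!slot.ready) {
      slot.binary = compile_(v);
      slot.ready = true;
      ++compile_count_;
    }
    return slot.binary;
  }

  int compile_count() const { return compile_count_; }

 private:
  std::array<KernelSlot, static_cast<std::size_t>(Variant::Count)> slots_;
  std::function<std::string(Variant)> compile_;
  int compile_count_ = 0;
};
```

Because the slot index is computed from input properties alone, a cache hit needs no string hashing or map lookup, which is consistent with the launch overhead being on par with precompiled dispatch.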

### RECAP & FUTURE WORK

The jiterator is PyTorch’s first just-in-time compilation system that runs elementwise CUDA eager operations. It provides a TensorIterator-compatible developer interface, can just-in-time generate the same set of kernels we precompile today, and once those kernels are compiled delivers performance comparable to precompiled operators.

Today a dozen operations are jiterated: digamma, trigamma, lgamma, gcd, lcm, i0, i0e, i1, i1e, ndtri, erfcx and zeta. Delaying the compilation of their CUDA kernels has reduced the CUDA context size of PyTorch’s native kernels by 7% (37.5 MB) since December. For builds with dynamic linking, however, only about two fifths of PyTorch’s initial CUDA context is native operations, so this leads to a more modest 3% improvement overall (see this workplace post for a more detailed breakdown of PyTorch’s CUDA context). While working on the jiterator we also identified additional opportunities to reduce our CUDA context size, leading to a total reduction of 62 MB (12%) in context size due to PyTorch native kernels, which translates to 7% improvement (60 MB) in overall context size.

Cold build times also improved. On the commit preceding the jiterator, a 20-core cold build took 17 minutes wall time and 318 minutes user time. Cold builds now take 14 minutes wall time and 277 minutes user time (a decrease of ~17.5% in wall time). Of course other commits over the past month may also have impacted build time.

Looking ahead, the biggest challenge for just-in-time compilation of eager operators remains just-in-time compilation speed. One idea to address this issue is creating a cache that’s persistent across processes; another is to use alternate compilation strategies (Zach discusses some ideas in his most recent MinTorch post).

We also plan to continue extending and tuning the jiterator while working on just-in-time compilation speed. There are a few types of elementwise operations the jiterator doesn’t support yet, and PyTorch doesn’t have the headers nvrtc needs for complex math (our colleagues at NVIDIA are working on this problem, too, for nvFuser). Supporting reductions, and even making it easy to jit any CUDA kernel, are interesting directions, too.

If you have questions, comments or are interested in working on some of these issues please reach out by commenting below or contacting us directly!