Keeping PyTorch's Ops Maintainable: The Jiterator

TL;DR
We built the “jiterator,” a CUDA TensorIterator interface that allows elementwise CUDA kernels to be just-in-time compiled when they’re used. This lets us greatly reduce the build time of these operators and means they have no impact on CUDA context size unless they’re used. Today in PyTorch the digamma, trigamma, lgamma, gcd, lcm, i0, i0e, i1, i1e, ndtri, erfcx and zeta operations are “jiterated.” These operators behave just like they did when precompiled and have comparable performance after they’re just-in-time compiled. Jiterating these operations alone has reduced PyTorch’s initial CUDA context size by 3%, and a few additional targeted optimizations identified while working on the jiterator have produced a total reduction of 7% since December.

Just-in-time compilation is not without its own challenges and drawbacks, however, and this note discusses the design of the jiterator, its impact, and planned future work. We’re excited to share the first “just-in-time” compilation system for PyTorch’s eager mode and how it’s keeping PyTorch maintainable even as we continue to add operators.

GOALS

The Jiterator is intended to do three things:
  • Faithfully reproduce PyTorch’s current eager behavior and performance while just-in-time compiling elementwise CUDA operators
  • Reduce PyTorch’s build time and keep the build time reasonable as we add operators
  • Reduce PyTorch’s initial CUDA context size and keep it reasonable as we add operators
To achieve these goals we reuse TensorIterator and carefully mimic the behavior of Loops.cuh and CUDALoops.cuh to just-in-time compile the same types of kernels that are precompiled today. We use nvrtc for just-in-time compilation of CUDA C strings because it’s a familiar technology used by existing code generators. This approach provides an easy-to-use TensorIterator-compatible interface that creates operations which behave and (after just-in-time compilation) perform just like their precompiled counterparts.
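
For readers unfamiliar with nvrtc, here is a minimal sketch of compiling a CUDA C string and retrieving a launchable function through the driver API. It is illustrative only, not the jiterator’s actual code path: error handling is omitted, the compile options are assumptions, and a current CUDA context is assumed.

// Illustrative nvrtc sketch; not the jiterator's actual code generation path.
// Assumes a CUDA context is already current and omits error checking.
#include <cuda.h>
#include <nvrtc.h>
#include <string>
#include <vector>

CUfunction compile_kernel(const std::string& src, const char* kernel_name) {
  // Compile the CUDA C source string to PTX.
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, src.c_str(), "jit_kernel.cu", 0, nullptr, nullptr);
  const char* opts[] = {"--std=c++14"};  // options here are illustrative
  nvrtcCompileProgram(prog, 1, opts);

  size_t ptx_size = 0;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::vector<char> ptx(ptx_size);
  nvrtcGetPTX(prog, ptx.data());
  nvrtcDestroyProgram(&prog);

  // Load the PTX with the driver API and look up the kernel entry point.
  CUmodule module;
  cuModuleLoadData(&module, ptx.data());
  CUfunction fn;
  cuModuleGetFunction(&fn, module, kernel_name);
  return fn;
}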

THE JITERATOR

Every op we add to PyTorch comes with costs, and two of those costs are increased build time and a larger CUDA context size if the operator has its own CUDA kernel. There are a variety of methods we can use to mitigate these impacts, but the ultimate hammer is just-in-time compiling — AKA “jitting” — operations when they’re used. Jitting ops is not without challenges and drawbacks of its own, however:
  • Just-in-time compilation techniques, like passing a CUDA C string to nvrtc, have different limitations than traditional precompilation of CUDA C
  • Practical just-in-time compilation techniques (using nvrtc or writing ptx and compiling with ptxas) have significant compilation times that make the first call to just-in-time compiled operators slow
  • Matching precompiled performance requires identifying the set of kernels that can be just-in-time compiled
The next sections elaborate on each of the challenges and how the jiterator addresses them (or plans to address them).

A TENSORITERATOR-COMPATIBLE INTERFACE

PyTorch already has a lot of architecture to facilitate writing elementwise kernels using dispatch macros and TensorIterator. The relevant snippets for the CUDA implementation of gcd (greatest common divisor) after a TensorIterator object has been constructed are:

AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cuda", [&]() {
  gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
    return calc_gcd(a, b);
  });
});
    
template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) {
  scalar_t a = ::abs(a_in);
  scalar_t b = ::abs(b_in);
  while (a != 0) {
    scalar_t c = a;
    a = b % a;
    b = c;
  }
  return b;
}

These snippets describe gcd’s “functor,” the operation called from a TensorIterator kernel; the kernel takes care of the tensors’ shapes, contiguities, and type promotion (which are encoded in the iter object). The TensorIterator interface abstracts a tremendous amount of complexity away from developers, making it straightforward to define the gcd operation as a functor of two scalars.

Now let’s look at the analogous snippets in the jiterated version of gcd:

AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cuda", [&]() {
  jitted_gpu_kernel</*name=*/gcd_name,
                    /*return_dtype=*/ scalar_t,
                    /*common_dtype=*/ scalar_t,
                    /*arity=*/ 2>(iter, gcd_string);
});

const auto gcd_string = jiterator_stringify(
  template <typename T>
  T gcd(const T a_in, const T b_in) {
    T a = abs(a_in);
    T b = abs(b_in);

    while (a != T{0}) {
      T c = a;
      a = b % a;
      b = c;
    }

    return b;
  }
); // gcd_string

The code is pretty similar. The dispatch macro and functor are effectively the same (although the functor is now passed to a macro). The call to jitted_gpu_kernel() may look a little intimidating, but it just describes some properties of the gcd operation explicitly — like how many arguments the functor takes. To create the complete kernel the jiterator inserts the string representation of the gcd functor into a string representation of a TensorIterator kernel before shipping the entire string to nvrtc.
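
As a rough illustration of the stringification step, a macro like the following could turn the functor’s source tokens into a C++ string at compile time. This is a sketch under that assumption; the actual jiterator_stringify macro in PyTorch may be defined differently.

// Sketch of a stringification macro; the real jiterator_stringify may differ.
// The preprocessor turns the functor's source tokens into a string literal,
// which can then be spliced into the kernel template that is sent to nvrtc.
#include <string>

#define stringify_code(...) #__VA_ARGS__
#define jiterator_stringify(...) std::string(stringify_code(__VA_ARGS__))

With definitions along these lines, gcd_string above would hold the text of the gcd template as a std::string, ready to be concatenated into the full kernel source.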

We think using the jiterator is comparable to writing precompiled CUDA kernels, and we’ve successfully ported some of our most complicated elementwise CUDA kernels to it.

COMPILATION TIME

When gcd was precompiled it took about 58 microseconds to run (on a devfair, with CUDA long tensors with 2^16 elements). The first time the jiterated gcd op is called, however, it takes over 2500 times as long to run! This happens because the first call to gcd has to just-in-time compile the kernel. Subsequent calls to gcd that can reuse that kernel take the same time as when gcd was precompiled, however.
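
As a rough illustration of this pattern (not the benchmark above), a small libtorch program could time the first and second calls. This sketch assumes a CUDA-enabled libtorch build where gcd is jiterated, and the measured times include CPU launch overhead as well as kernel time.

// Hypothetical timing sketch: the first call pays the just-in-time compilation
// cost, the second call reuses the cached kernel.
#include <torch/torch.h>
#include <chrono>
#include <iostream>

int main() {
  auto opts = torch::dtype(torch::kLong).device(torch::kCUDA);
  auto a = torch::randint(1, 1000, {1 << 16}, opts);
  auto b = torch::randint(1, 1000, {1 << 16}, opts);

  auto time_call = [&]() {
    torch::cuda::synchronize();
    auto start = std::chrono::steady_clock::now();
    auto out = torch::gcd(a, b);
    torch::cuda::synchronize();
    auto end = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::micro>(end - start).count();
  };

  std::cout << "first call (includes jit compilation): " << time_call() << " us\n";
  std::cout << "second call (cached kernel):           " << time_call() << " us\n";
}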

Calls to JAX operations exhibit the same pattern, but frameworks like JAX aren’t intended for eagerly running operations the way PyTorch is. Compilation time may be especially painful for users progressively writing and incrementally validating their scripts, because every run triggers the same just-in-time compilations. To limit the impact of just-in-time compilation performance we’re only jiterating little-used operators ahead of PyTorch 1.11. Better just-in-time compilation strategies are an interesting direction for Future Work (see below).

MATCHING PRECOMPILED PERFORMANCE (AFTER THE INITIAL COMPILATION)

While the first call to a just-in-time compiled operator is slow, subsequent calls can be as fast as if the operator were precompiled. The trick is identifying the set of kernels we could generate in all possible scenarios and precompiling the dispatch to them.

For non-jiterated operations PyTorch compiles a fixed set of kernels when it’s built. Elementwise TensorIterator kernels get six for every dtype they support: one for arbitrary strides, another for arbitrary strides with type promotion, another for contiguous inputs with type casting, and finally three vectorized kernels (with different vectorization lengths) for contiguous inputs without type casting. So if an elementwise operator supports two datatypes (like float and double) it compiles 12 kernels, and if it supports four it compiles 24, etc. For elementwise binary operations that support scalars three times as many kernels are precompiled (36 instead of 12, 72 instead of 24)! That’s a lot of kernels, and it’s easy to see why especially complicated kernels, like the zeta or erfcx operations, might take a long time to compile traditionally.

The jiterator can compile the same set of kernels, and because this set is known when PyTorch is built we can construct the runtime mapping from all inputs to all possible kernels at build time, too, making dispatch to cached just-in-time compiled kernels as fast as non-jiterated dispatch. Callgrind benchmarking shows that the CPU instruction count is actually slightly smaller when a jit-compiled kernel is launched than when the same operation is launched the regular way (182565 instructions for 10 iterations with the jiterator vs 192966 instructions without). Timing the CPU overhead of launching a kernel is very noisy, but shows results around 5 us in both cases. A more complete performance analysis that measures the performance of the GPU kernel shows there are some input shapes where the jiterator’s kernels are a shade faster than eager kernels, and others where they’re a shade slower, but they’re effectively the same. P474088187 shows achieved bandwidth for the gcd operation for different sizes and broadcasting patterns with the regular and jitted implementations.
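
As a rough illustration of that dispatch (hypothetical names and layout, not PyTorch’s internals): because the set of kernel variants is fixed at build time, the runtime lookup can be a small per-dtype table from a variant index to a cached function handle, compiled lazily on first use.

// Hypothetical sketch of dispatch to cached jit-compiled kernels; the names
// and layout are illustrative and do not mirror PyTorch's internals.
#include <cuda.h>
#include <array>
#include <mutex>

// One slot per kernel variant for a given dtype, e.g.:
//   0: arbitrary strides              1: arbitrary strides + type promotion
//   2: contiguous with casting        3-5: contiguous, vectorized (three lengths)
struct KernelCache {
  std::array<CUfunction, 6> fns{};  // nullptr until compiled
  std::mutex mutex;
};

CUfunction get_or_compile(KernelCache& cache, int variant,
                          CUfunction (*compile)(int /*variant*/)) {
  std::lock_guard<std::mutex> guard(cache.mutex);
  if (cache.fns[variant] == nullptr) {
    cache.fns[variant] = compile(variant);  // slow path: just-in-time compile
  }
  return cache.fns[variant];                // fast path: cached function
}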

RECAP & FUTURE WORK

The jiterator is PyTorch’s first just-in-time compilation system that runs elementwise CUDA eager operations. It provides a TensorIterator-compatible developer interface, can just-in-time generate the same set of kernels we precompile today, and once those kernels are compiled delivers performance comparable to precompiled operators.

Today a dozen operations are jiterated: digamma, trigamma, lgamma, gcd, lcm, i0, i0e, i1, i1e, ndtri, erfcx and zeta. Delaying the compilation of their CUDA kernels has reduced the CUDA context size of PyTorch’s native kernels by 7% (37.5 MB) since December. For builds with dynamic linking, however, only about two fifths of PyTorch’s initial CUDA context comes from native operations, so this translates to a more modest 3% improvement overall (see this workplace post for a more detailed breakdown of PyTorch’s CUDA context). While working on the jiterator we also identified additional opportunities to reduce our CUDA context size, leading to a total reduction of 62 MB (12%) in context size due to PyTorch native kernels, which translates to a 7% improvement (60 MB) in overall context size.

Cold build times also improved. On the commit preceding the jiterator, a 20-core cold build took 17 minutes of wall time and 318 minutes of user time. Cold builds now take 14 minutes of wall time and 277 minutes of user time (a decrease of ~17.5% in real time). Of course, other commits over the past month may also have impacted build time.

Looking ahead, the biggest challenge for just-in-time compilation of eager operators remains just-in-time compilation speed. One idea to address this issue is creating a cache that’s persistent across processes; another approach may be alternate compilation strategies (Zach discusses some ideas in his most recent MinTorch post).

We also plan to continue extending and tuning the jiterator while working on just-in-time compilation speed. There are a few types of elementwise operations the jiterator doesn’t support yet, and PyTorch doesn’t have the headers nvrtc needs for complex math (our colleagues at NVIDIA are working on this problem, too, for nvFuser). Supporting reductions, and even making it easy to jit any CUDA kernel, are interesting directions too.

If you have questions, comments or are interested in working on some of these issues please reach out by commenting below or contacting us directly!

First of all it is very cool

One idea to address this issue is creating a cache that’s persistent across processes

Question: doesn’t NVRTC cache kernels similarly to the OpenCL cache?

In any case, I used sqlite as a cache database to speed up operations for non-NVIDIA GPUs via OpenCL kernel caches in the dlprimitives/pytorch dlprimitives-opencl implementation: https://github.com/artyom-beilis/dlprimitives/blob/master/src/binary_cache.hpp

It works fantastically and greatly improves startup times on AMD and Intel devices.

The relevant snippets for the CUDA implementation…

Would you consider making this extendable to non-CUDA runtime APIs in the future (I work on OpenCL)?

I apply a somewhat similar technique to implement operators quickly for OpenCL.

For example:

broadcasting
add: https://github.com/artyom-beilis/pytorch_dlprim/blob/master/src/pointwise_ops.cpp#L201

reduction + broadcasting
mean: https://github.com/artyom-beilis/pytorch_dlprim/blob/master/src/pointwise_ops.cpp#L362

If you consider making it extendable to non-CUDA devices, it may significantly help implementations for other backends.

First of all it is very cool

Thanks!

Question: doesn’t NVRTC cache kernels similarly to the OpenCL cache?

There may be some caching that occurs, but we also want to avoid string manipulation and compilation to ptx.

In any case, I used sqlite as a cache database to speed up operations for non-NVIDIA GPUs via OpenCL…

Cool. I’m curious: how long does it take to construct a key and look up a kernel in such a system? We are thinking about cache designs now and focusing on minimizing latency.

Would you consider making this extendable to non-CUDA runtime APIs in the future (I work on OpenCL)?

Yes, but it would require work from someone more familiar with those device types.

There may be some caching that occurs, but we also want to avoid string manipulation and compilation to ptx.

I don’t know if it is the same for nvrtc, but for OpenCL I found that caching ptx actually slows the process down a little, since NVIDIA’s OpenCL driver already caches binaries based on source code. I don’t know if it does the same for nvrtc, though.

Cool. I’m curious: how long does it take to construct a key and look up a kernel in such a system? We are thinking about cache designs now and focusing on minimizing latency.

I actually have a system with two “keys”:

  1. A cache on disk, where I calculate the sha1 of the entire source code, defines, compilation parameters, driver version, device name, etc. This serves as the DB key. It is quite fast, though I didn’t measure it since it is executed only once, and it speeds things up significantly.

    I can measure the entire process if that would help you.

  2. An in-memory key: once I compile a kernel I keep it in memory. The key here is a small name, since it is only valid for this process and isn’t affected by possible changes between versions, so the lookup is fast and doesn’t require a sha1 calculation.

Only when there is a miss do I build the source code, compute its sha1, and query the sqlite3 DB. But that is done only once per process.

For example, when I call

 dlprim::core::pointwise_operation_broadcast({x0,x1,x2},{y0},{w0},
                              "y0 = x0 + w0 * x1 * x2;",
                              getExecutionContext(self));

The key is “y0 = x0 + w0 * x1 * x2;” plus several parameters, but other things could be used, like file name + line number, etc. This is fast, especially since execution is asynchronous: you don’t bottleneck the GPU, because you can enqueue much faster than it runs heavy kernels like convs or gemms.
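
For readers following along, here is a rough sketch of that two-level lookup; the helpers (compute_sha1, load_from_sqlite, store_in_sqlite, compile_source) are placeholder stubs standing in for the real hashing, sqlite, and compilation code in dlprimitives.

// Rough sketch of the two-level cache described above; the helpers below are
// placeholder stubs, not the actual dlprimitives code.
#include <string>
#include <unordered_map>
#include <vector>

using Binary = std::vector<char>;

std::string compute_sha1(const std::string& s) { return s; }                 // placeholder
Binary load_from_sqlite(const std::string& /*key*/) { return {}; }           // placeholder
void store_in_sqlite(const std::string& /*key*/, const Binary& /*bin*/) {}   // placeholder
Binary compile_source(const std::string& src) { return {src.begin(), src.end()}; }  // placeholder

Binary get_kernel(const std::string& small_key, const std::string& source_and_params) {
  // Level 1: per-process in-memory cache keyed by a short name.
  static std::unordered_map<std::string, Binary> memory_cache;
  auto it = memory_cache.find(small_key);
  if (it != memory_cache.end()) return it->second;

  // Level 2: on-disk sqlite cache keyed by a sha1 of the source, defines,
  // driver version, device name, etc.
  const std::string sha1 = compute_sha1(source_and_params);
  Binary bin = load_from_sqlite(sha1);
  if (bin.empty()) {
    bin = compile_source(source_and_params);  // true miss: compile and persist
    store_in_sqlite(sha1, bin);
  }
  memory_cache[small_key] = bin;
  return bin;
}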

Good that you asked about times; I did some measurements, see the results below:

TL;DR: on average I get ~0.1ms per binary kernel on AMD and ~0.14ms for NVIDIA’s PTX when fetching a kernel from the cache based on source code (i.e. compute the full source code sha1, run a select, and get the binary blob).

In more detail:

I keep an LRU timestamp so that when I need to clear the cache I can remove the oldest items. Originally I updated it on every fetch; that means a write transaction and a commit to disk, which takes much more time.

When I changed it to update the LRU only if more than 24h had passed, so it isn’t updated frequently, the time went back to ~0.1ms. If the LRU is updated on every fetch the average time is ~6.5ms, since it is a full ACID transaction (on an SSD). On-demand LRU handling increases the time by about 8% compared to not updating the LRU at all.

There is also a small difference in time between larger and smaller kernels: for example, pointwise kernels (the ones you use the JIT for) take about 0.05-0.07ms, while larger fused gemm-conv kernels or winograd kernels take 0.10-0.15ms.

Notes:

  • Time is measured from providing the source code to getting the binary from the cache.
  • Tested on ResNet18, inference.
  • Tested on an AMD 6600XT GPU, which caches real binaries rather than PTX, so times may be affected by binary code size.
  • NVIDIA 960 PTX caching times were ~35% slower, taking 0.14ms on average, most likely due to binary size.
  • CPU i5-6600, DDR3 memory, M2 SSD disk, OS Ubuntu 18.04.
Kernel Name | LRU Update (ms) | No LRU (ms) | LRU On Demand (ms) | Nvidia LRU On Demand (ms)
sgemm | 8.66 | 0.14 | 0.14 | 0.23
bn_sums | 6.27 | 0.09 | 0.09 | 0.10
bn_sums | 5.83 | 0.07 | 0.07 | 0.09
bn_utils | 6.40 | 0.08 | 0.09 | 0.09
scal | 5.97 | 0.03 | 0.03 | 0.05
pooling | 6.14 | 0.10 | 0.12 | 0.11
winograd_fwd | 6.24 | 0.10 | 0.12 | 0.20
bn_sums | 6.66 | 0.06 | 0.07 | 0.09
bn_sums | 6.58 | 0.06 | 0.07 | 0.11
pointwise_broadcast_reduce | 6.48 | 0.07 | 0.07 | 0.09
pointwise_broadcast_reduce | 6.49 | 0.07 | 0.08 | 0.09
sgemm | 6.91 | 0.19 | 0.13 | 0.29
bn_sums | 6.24 | 0.07 | 0.08 | 0.09
bn_sums | 6.25 | 0.06 | 0.06 | 0.09
sgemm | 6.92 | 0.14 | 0.13 | 0.24
sgemm | 7.33 | 0.15 | 0.22 | 0.22
sgemm | 7.09 | 0.14 | 0.14 | 0.21
sgemm | 7.25 | 0.15 | 0.14 | 0.24
sgemm | 7.13 | 0.14 | 0.13 | 0.19
global_pooling | 6.53 | 0.05 | 0.06 | 0.08
sgemm | 6.92 | 0.14 | 0.17 | 0.16
random | 5.90 | 0.08 | 0.09 | 0.07
pointwise | 5.16 | 0.05 | 0.05 | 0.06

That’s interesting. Is the 0.1 milliseconds for loading the kernel from a file, or the added latency when calling a kernel already loaded into memory?

We actually just implemented a cache (and will have a post soon), but our time to load a kernel from a file is more like 1 millisecond than 0.1 milliseconds.

~0.1ms is for loading the kernel from the file. Once it is in memory I don’t access the file, i.e. it is the execution time of this “pseudo-function”:

binary_blob get_kernel_from_cache(source_code,parameters);

Where the binary blob is in-memory binary code like std::vector<char>, and the sources are the actual sources that need to be compiled if the cache misses.

Also note that results like this may depend on the disk and its speed, OS disk caching, the size of the cache, etc.

In general, if you need some kind of structured data storage, Sqlite3 is virtually as fast as accessing raw files while being fully ACID. It is a brilliant small DB.

I would like to know: starting with which version of ROCm is the Jiterator supported?