A pure Python implementation of roi_align that looks just like its CUDA kernel

vision_maskrcnn has been failing in the PT2 benchmark suite for as long as I can remember, and part of the reason is that it gives different results when we run it twice in eager mode. It turned out this is because it uses a few operators, like upsample_bilinear and roi_align, which have nondeterministic backwards. Fixing upsample_bilinear’s nondeterminism turned out to be pretty easy: we have a Python decomposition which expresses the operator in terms of more elementary PyTorch operations, which are (1) differentiable (so we get backwards for free) and (2) supported for deterministic execution (forwards and backwards). (Of course, running this decomposition isn’t likely to be performant, but that’s what we have Inductor for!)

So I wanted to apply the same playbook to roi_align. Unfortunately, no one had written a loop-less, pure Python implementation of roi_align (and when I asked GPT-4 to write me one, it gave me an implementation that didn’t give the right result. I, uh, declined to debug further.) I was also short on time and couldn’t justify spending an entire day on this, so I needed a strategy for writing this implementation that would not require too much heavy brainpower.

Fortunately, torchvision’s CUDA implementation of roi_align is pretty simple: 130 LOC or so.

However, it is written in the usual style of CUDA kernels: it says how to compute a single output entry on an element-by-element basis. How do we convert this into a loop-free Python implementation?

Enter first class dimensions! There are a lot of ways to think about first class dimensions, but one way I like to think about them is that they give you an easy way to write code in PyTorch that closely resembles code with explicit loops.
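To make that concrete without depending on the functorch.dim API, here is a small sketch (the function names are hypothetical) of a loop-style computation next to the hand-written broadcasting that first class dims would otherwise do for you:

```python
import torch


def pairwise_loop(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # CUDA-kernel style: say how to compute one output entry (i, j) at a time.
    out = torch.empty(a.shape[0], b.shape[0], dtype=a.dtype)
    for i in range(a.shape[0]):
        for j in range(b.shape[0]):
            out[i, j] = a[i] * 2 + b[j]
    return out


def pairwise_broadcast(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Loop-free: give i and j their own positional axes and let broadcasting
    # do the iteration. With first class dims you would instead bind i and j
    # as dims and write the loop body as-is, roughly:
    #   i, j = dims(2); out = (a[i] * 2 + b[j]).order(i, j)
    return a[:, None] * 2 + b[None, :]


a = torch.arange(3.0)
b = torch.arange(4.0)
assert torch.equal(pairwise_loop(a, b), pairwise_broadcast(a, b))
```

The positional version works, but you have to decide by hand which axis each index lives on; that bookkeeping is what first class dims take off your plate.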

Here is the full roi_align implementation in Python using first class dims:

If you compare it side-by-side with the CUDA code, you’ll notice they are very similar. Here are some key differences in translation.

No need to reverse-engineer per-dimension indices from a linear position. One of the first things the CUDA kernel does is extract individual indices from a linearized index into the output tensor:

    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;
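For reference, here is the same decomposition written in Python, with a check that it round-trips. `unravel_nchw` is a hypothetical helper name, and Python’s `//` matches C integer division here only because every index is non-negative:

```python
def unravel_nchw(index, channels, pooled_height, pooled_width):
    # Same arithmetic as the CUDA kernel: peel dims off the linear index,
    # innermost (pw) first.
    pw = index % pooled_width
    ph = (index // pooled_width) % pooled_height
    c = (index // pooled_width // pooled_height) % channels
    n = index // pooled_width // pooled_height // channels
    return n, c, ph, pw


K, C, PH, PW = 2, 3, 4, 5
for index in range(K * C * PH * PW):
    n, c, ph, pw = unravel_nchw(index, C, PH, PW)
    # Re-linearizing (n, c, ph, pw) in row-major order recovers the index.
    assert ((n * C + c) * PH + ph) * PW + pw == index
```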

Instead, with first class dims we simply allocate a first class dim for each index. The sizes of n and c will be inferred automatically when we use them to index; ph and pw we bind explicitly to the output size.

    from functorch.dim import dims

    n, c, ph, pw = dims(4)
    ph.size = pooled_height
    pw.size = pooled_width

Pointer offsets turn into indexing operations. In the CUDA code, we select out a particular row of a 2D tensor by computing a pointer offset from the base pointer.

    const T* offset_rois = rois + n * 5;

rois is a (K, 5) tensor, so this gets you the n’th row. With first class dims, we can write this in the natural way:

    offset_rois = rois[n]

Something similar happens with input:

    const T* offset_input =
        input + (roi_batch_ind * channels + c) * height * width;

which becomes:

    offset_input = input[roi_batch_ind][c]
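Both translations rest on the same fact: for a contiguous tensor, a pointer offset from the base and a row index select the same elements. A quick sanity check, with arbitrary sizes and a hypothetical helper `from_offset` standing in for pointer arithmetic:

```python
import torch

K, N, C, H, W = 4, 2, 3, 4, 5
rois = torch.arange(K * 5, dtype=torch.float32).reshape(K, 5)
inp = torch.arange(N * C * H * W, dtype=torch.float32).reshape(N, C, H, W)


def from_offset(t: torch.Tensor, start: int, numel: int) -> torch.Tensor:
    # Read numel elements starting at a flat offset into t's contiguous buffer.
    return t.flatten()[start : start + numel]


for n in range(K):
    # CUDA: const T* offset_rois = rois + n * 5;
    assert torch.equal(from_offset(rois, n * 5, 5), rois[n])

for b in range(N):
    for c in range(C):
        # CUDA: input + (roi_batch_ind * channels + c) * height * width
        plane = from_offset(inp, (b * C + c) * H * W, H * W).reshape(H, W)
        assert torch.equal(plane, inp[b][c])
```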

Basic computation transcribes as-is. The magic of first-class dims is that once you’ve indexed out of a tensor with a first class dim, you can write the rest of your computation as if you’re operating element-by-element. For example, in the original code, we have:

    T roi_start_w = offset_rois[1] * spatial_scale - offset;

With first class dims, this transcribes into an identical line:

    roi_start_w = offset_rois[1] * spatial_scale - offset

In fact, this computation has to be vectorized over every row of rois, but first class dims take care of implicitly vmap’ing it, so I don’t need to do any of the book-keeping myself.
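If it helps to see the implicit vmap spelled out, torch.vmap gives a comparable effect in plain PyTorch: write the function for a single (5,)-shaped row and map it over dim 0 of rois. This is an analogy, not a claim about how first class dims are implemented; the ROI values are made up, and offset is taken as 0.5, which is what the kernel uses when aligned is true.

```python
import torch

spatial_scale = 0.5
offset = 0.5  # assumption: aligned=True

# Made-up ROIs: (batch_index, x1, y1, x2, y2) per row.
rois = torch.tensor([[0.0, 2.0, 4.0, 10.0, 12.0],
                     [1.0, 6.0, 0.0, 14.0, 8.0]])


def roi_start_w_of(offset_rois: torch.Tensor) -> torch.Tensor:
    # Written for one row, exactly like the CUDA line.
    return offset_rois[1] * spatial_scale - offset


# torch.vmap maps the per-row function over dim 0 of rois.
batched = torch.vmap(roi_start_w_of)(rois)
assert torch.equal(batched, rois[:, 1] * spatial_scale - offset)
```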

Control flow becomes torch.where (or clamp). If you have an if statement whose condition depends on vectorized values:

    if (y_low >= height - 1) {
      y_high = y_low = height - 1;
      y = (T)y_low;
    } else {
      y_high = y_low + 1;
    }

Just convert each internal assignment into a torch.where assignment, eliminating the explicit control flow:

    y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
    y_low = torch.where(y_low >= height - 1, height - 1, y_low)
    y = torch.where(y_low >= height - 1, y_low.to(input.dtype), y)
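One way to gain confidence in a transcription like this is to check it against a scalar version of the original branch. Below is a sketch with an arbitrary height and hypothetical function names. It transcribes `y = (T)y_low` literally as `y_low.to(y.dtype)`, and it uses truncation for `(int) y`, which is valid because the kernel has already clamped y to be non-negative by this point. Note that the third torch.where can reuse the predicate on the updated y_low, since clamping leaves it true exactly where the original y_low was >= height - 1.

```python
import torch

height = 4


def clamp_high_scalar(y: float):
    # Direct transcription of the CUDA branch for one sample point.
    y_low = int(y)
    if y_low >= height - 1:
        y_high = y_low = height - 1
        y = float(y_low)
    else:
        y_high = y_low + 1
    return y_low, y_high, y


def clamp_high_where(y: torch.Tensor):
    y_low = y.to(torch.int64)  # truncation, like (int) y for y >= 0
    y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
    y_low = torch.where(y_low >= height - 1, height - 1, y_low)
    y = torch.where(y_low >= height - 1, y_low.to(y.dtype), y)
    return y_low, y_high, y


ys = torch.tensor([0.0, 1.25, 2.5, 3.0, 3.75, 4.0])
yl, yh, yy = clamp_high_where(ys)
for i, v in enumerate(ys.tolist()):
    assert (yl[i].item(), yh[i].item(), yy[i].item()) == clamp_high_scalar(v)
```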

Sometimes, if statements can be conveniently expressed in a more idiomatic way. For example:

    if (y <= 0)
      y = 0;

can simply become:

    y = y.clamp(min=0)
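The clamp spelling and the mechanical torch.where spelling agree elementwise; a tiny check on made-up values:

```python
import torch

y = torch.tensor([-0.75, -0.25, 0.0, 0.5, 2.0])
# CUDA: if (y <= 0) y = 0;  written mechanically with torch.where ...
via_where = torch.where(y <= 0, 0.0, y)
# ... and with clamp, which leaves non-negative values untouched.
via_clamp = y.clamp(min=0)
assert torch.equal(via_where, via_clamp)
```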

Handling the ROIs. Everything ported swimmingly well, except for this part:

    T output_val = 0.;
    for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1
      const T y = roi_start_h + ph * bin_size_h +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = roi_start_w + pw * bin_size_w +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);

        T val = bilinear_interpolate(offset_input, height, width, y, x, index);
        output_val += val;
      }
    }

Hmm, that outer loop goes from 0 to roi_bin_grid_h. The standard playbook is to allocate a new first class dim for each loop variable. But this doesn’t work:

    iy, ix = dims(2)
    iy.size = roi_bin_grid_h
    ix.size = roi_bin_grid_w

because roi_bin_grid_h and roi_bin_grid_w are themselves tensors! In fact, roi_align internally has some non-uniform computation: each CUDA thread may do a different amount of work depending on how big its assigned ROI is. Unfortunately, first class dims don’t directly support this (we’d need some sort of nested tensor support), so I needed an alternate strategy.

Fortunately, I wasn’t trying to write a fast kernel, so I came up with a compromise: just do the compute over the entire input image, and then mask out the elements I didn’t care about. There were two places where I needed to mask. First, when I do bilinear interpolation, I index into the input tensor.

    // do bilinear interpolation
    T v1 = input[y_low * width + x_low];
    T v2 = input[y_low * width + x_high];
    T v3 = input[y_high * width + x_low];
    T v4 = input[y_high * width + x_high];

If I don’t mask out indices that correspond to sample points outside an ROI’s grid, I might trigger an illegal memory access. So in the new implementation I have to mask out those elements. (Though, on further reflection, the kernel may already be working hard to avoid out-of-bounds accesses.)

    # do bilinear interpolation, but respect the masking!
    def masked_index(y, x):
        y = torch.where(ymask, y, 0)
        x = torch.where(xmask, x, 0)
        return input[y, x]

    v1 = masked_index(y_low, x_low)
    v2 = masked_index(y_low, x_high)
    v3 = masked_index(y_high, x_low)
    v4 = masked_index(y_high, x_high)

NB: here, xmask and ymask are

    ymask = iy < roi_bin_grid_h
    xmask = ix < roi_bin_grid_w
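The post does not show how iy, ix and the masks are constructed. Below is one hedged way to do it in plain PyTorch with positional axes, consistent with computing over the whole input image: iy and ix range over the image height and width (assumed here to bound every ROI’s grid size), each ROI’s grid size is a tensor, and masked_index is re-implemented for a single image plane. The grid sizes and sample coordinates are made up so that some masked-out samples would fall outside the image.

```python
import torch

H, W = 6, 5
image = torch.arange(H * W, dtype=torch.float32).reshape(H, W)

# Hypothetical data-dependent grid sizes, one per ROI: shape [K].
roi_bin_grid_h = torch.tensor([2, 4, 1])
roi_bin_grid_w = torch.tensor([3, 1, 2])

# Sample over the full image extent, then mask per ROI by broadcasting.
iy = torch.arange(H)                            # [IY]
ix = torch.arange(W)                            # [IX]
ymask = iy[None, :] < roi_bin_grid_h[:, None]   # [K, IY]
xmask = ix[None, :] < roi_bin_grid_w[:, None]   # [K, IX]


def masked_index(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Redirect masked-out samples to pixel (0, 0) so every read is in bounds;
    # their values still have to be zeroed out before the sum.
    y = torch.where(ymask[:, :, None], y, 0)    # [K, IY, 1]
    x = torch.where(xmask[:, None, :], x, 0)    # [K, 1, IX]
    return image[y, x]                          # broadcasts to [K, IY, IX]


# Made-up sample coordinates: valid ones stay inside the image, but the
# largest iy / ix would index row H / column W if left unmasked.
vals = masked_index((iy + 1)[None, :, None], (ix + 1)[None, None, :])
```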

Additionally, the loop over iy and ix in the original CUDA code sums all the values into output_val. When I perform my summation, I must make sure to mask out all elements that wouldn’t actually have participated in the sum.

    val = torch.where(ymask, val, 0)
    val = torch.where(xmask, val, 0)
    output = val.sum((iy, ix))
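A sketch of why the masked sum reproduces the kernel’s data-dependent loop, checked against a per-ROI Python loop. Shapes and grid sizes are made up, and positional axes 1 and 2 stand in for the first class dims iy and ix. torchvision’s CUDA kernel then divides output_val by count = max(roi_bin_grid_h * roi_bin_grid_w, 1), so the sketch ends with that averaging step as well.

```python
import torch

torch.manual_seed(0)
K, IY, IX = 3, 4, 4
val = torch.randn(K, IY, IX)                  # one value per (roi, iy, ix) sample
roi_bin_grid_h = torch.tensor([2, 4, 1])      # hypothetical per-ROI grid sizes
roi_bin_grid_w = torch.tensor([3, 1, 4])

ymask = torch.arange(IY)[None, :] < roi_bin_grid_h[:, None]   # [K, IY]
xmask = torch.arange(IX)[None, :] < roi_bin_grid_w[:, None]   # [K, IX]

# Zero out samples outside each ROI's grid, then reduce over both grid axes.
masked = torch.where(ymask[:, :, None], val, 0)
masked = torch.where(xmask[:, None, :], masked, 0)
output = masked.sum((1, 2))                                   # [K]

# Reference: the kernel's loop with data-dependent bounds, one ROI at a time.
for k in range(K):
    ref = sum(
        val[k, iy, ix]
        for iy in range(int(roi_bin_grid_h[k]))
        for ix in range(int(roi_bin_grid_w[k]))
    )
    assert torch.allclose(output[k], ref)

# Average over the number of samples, as the kernel does with `count`.
count = (roi_bin_grid_h * roi_bin_grid_w).clamp(min=1)
averaged = output / count
```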

Conclusion. All in all, it took only 1.5 hours to do the port from start to end (with a few meetings in between!), and the forwards code worked on the first try!

$ python a.py

First class dims are an incredibly powerful tool for writing kernels in Python, and chances are Inductor can do a pretty good job optimizing them. (In our case, we would expect Inductor to fuse away all of our intermediate tensors, leaving as the only inefficiency the work done over the entire input tensor for each region, rather than only over the pixels in the region of interest.) The next time you need to write a kernel with persnickety index-by-index compute in Python, give it a try!

Some ideas for future work. There were three pieces of infrastructure that would have helped quite a bit here:

  • Masked tensors. I manually had to take care of masks, which is annoying when you have to mask out an indexing operation or beware of illegal data access. A masked tensor abstraction that works with first class dims probably could have eliminated this manual book-keeping.
  • Nested tensors. If first class dims supported nested tensors, it would be feasible to create a first class dim whose size was a tensor (rather than a single int), indicating non-uniformity. When vectorized, this would result in a nested tensor where the variable-size ROIs were packed together into a contiguous buffer. This eliminates the need for padding (which is what the masking in this kernel amounts to), making the pointwise operations run faster. One downside, however, is that to pack the ROIs into a nested tensor, you would have to do a DtoH sync (so that you would know how big a buffer to allocate).
  • On-device loop higher order operator. The original torchvision kernel doesn’t bother allocating the buffer, because it never materializes the nested tensor; instead, it performs the summation directly on-device with a loop with data-dependent bound. If we had some way of expressing this concept directly in Python, we could likely generate a kernel in Triton that is just as good as the original. (This is not as useful for eager-only compute though, since you have to codegen this kernel for good performance.)
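The padded-versus-packed trade-off can be sketched in plain PyTorch. This is not torch.nested and not an API from the post: a flat buffer plus per-element segment ids stands in for the nested tensor, and the sizes are made up. The boolean-mask indexing that builds the packed buffer has a data-dependent output size, which is where a GPU would need the DtoH sync.

```python
import torch

torch.manual_seed(0)
sizes = torch.tensor([2, 5, 1])          # hypothetical samples per ROI
K, max_size = sizes.numel(), 5

# Padded layout (what the masking in this post amounts to): [K, max_size] + mask.
padded = torch.randn(K, max_size)
mask = torch.arange(max_size)[None, :] < sizes[:, None]
padded_sum = torch.where(mask, padded, 0).sum(1)

# Packed layout: only valid samples, back to back in one buffer. Its length
# (sizes.sum()) is data-dependent, so on a GPU allocating it needs a DtoH sync.
packed = padded[mask]                                # [8]
segment_ids = torch.repeat_interleave(torch.arange(K), sizes)
packed_sum = torch.zeros(K).index_add_(0, segment_ids, packed)

assert torch.allclose(padded_sum, packed_sum)
```

The packed buffer holds 8 values instead of 15, which is the saving the nested tensor idea is after.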

Actually, there’s one more thing to be careful about. If you use advanced indexing (as is the case in offset_input = input[roi_batch_ind][c]), it’s important to delay this indexing as long as possible, until the point where you know the indices of all the other dimensions as well. Otherwise, in eager mode you will materialize every slice of input selected by offset_input and OOM.
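A small illustration of the difference, with made-up shapes and random sample coordinates standing in for the per-ROI sample points: indexing by roi_batch_ind first materializes a [K, C, H, W] copy of the input, whereas a single advanced-indexing call, issued once all indices are known, only produces the [K, C, S] values actually needed.

```python
import torch

torch.manual_seed(0)
N, C, H, W, K, S = 2, 3, 8, 8, 5, 4
inp = torch.randn(N, C, H, W)
roi_batch_ind = torch.tensor([0, 1, 1, 0, 1])    # [K]
ys = torch.randint(0, H, (K, S))                  # made-up sample rows
xs = torch.randint(0, W, (K, S))                  # made-up sample columns
c = torch.arange(C)

# Early: index by ROI first. In eager mode this copies a full image per ROI.
per_roi = inp[roi_batch_ind]                      # [K, C, H, W]
early = per_roi[torch.arange(K)[:, None, None], c[None, :, None],
                ys[:, None, :], xs[:, None, :]]   # [K, C, S]

# Delayed: one advanced-indexing op once every index is known.
late = inp[roi_batch_ind[:, None, None], c[None, :, None],
           ys[:, None, :], xs[:, None, :]]        # [K, C, S]

assert torch.equal(early, late)
```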

Another interesting question is whether you can use broadcasting within advanced indexing to perform all the advanced indexing in one go. This helps reduce the number of reads and writes and use the cache more efficiently. Sometimes torchinductor is clever enough to do this itself, but sometimes it’s not.

An example of this is when you have an outer product between two sets of indices, as you have here. Sometimes this suggests how to vectorize the whole operation and perform it via operations on the whole tensor.

Performing all the indexing in one go also helps with the generated backwards code. The backwards code for indexing is index_put_, which is often more efficient to do in one go rather than in 4 chunks.
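Here is what that outer product looks like for the four bilinear-interpolation reads, as a sketch on a single made-up image plane: stacking the low and high index per axis lets one advanced-indexing call gather all 2 x 2 corners for every sample. In autograd, that single indexing op has a single backward (an accumulating index_put_) instead of four.

```python
import torch

torch.manual_seed(0)
H, W, S = 6, 7, 10
plane = torch.randn(H, W)
y_low = torch.randint(0, H - 1, (S,))     # made-up sample positions
x_low = torch.randint(0, W - 1, (S,))
y_high, x_high = y_low + 1, x_low + 1

# Four separate gathers, one per corner, as in the kernel.
v1 = plane[y_low, x_low]
v2 = plane[y_low, x_high]
v3 = plane[y_high, x_low]
v4 = plane[y_high, x_high]

# One gather: [S, 2, 1] rows against [S, 1, 2] columns broadcast to the
# 2 x 2 outer product of corners for every sample -> [S, 2, 2].
ys = torch.stack([y_low, y_high], dim=1)   # [S, 2]
xs = torch.stack([x_low, x_high], dim=1)   # [S, 2]
corners = plane[ys[:, :, None], xs[:, None, :]]

assert torch.equal(corners[:, 0, 0], v1)
assert torch.equal(corners[:, 0, 1], v2)
assert torch.equal(corners[:, 1, 0], v3)
assert torch.equal(corners[:, 1, 1], v4)
```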

Well, advanced indexing automatically does broadcasting, so I would guess that it is not so hard to make sure this happens automatically when you’re using first class dims. Something that I’m not too sure about is whether or not torchinductor always makes good decisions here though.

The final, optimized kernel can be found here: Add deterministic, pure-Python roi_align implementation by ezyang · Pull Request #7587 · pytorch/vision · GitHub