vision_maskrcnn has been failing in the PT2 benchmark suite for as long as I can remember, and part of the reason is that when we run it twice in eager mode, it gives different results. It turned out this is because it uses a few operators, like upsample_bilinear and roi_align, which have nondeterministic backwards. Fixing upsample_bilinear’s non-determinism turned out to be pretty easy: we have a Python decomposition which expresses the operator in terms of more elementary PyTorch operations which are (1) differentiable (so we get backwards for free) and (2) support deterministic execution, both forwards and backwards. (Of course, running this decomposition isn’t likely to be performant, but that’s what we have Inductor for!)
So I wanted to apply the same playbook to roi_align. But unfortunately no one had written a loop-less, pure Python implementation of roi_align (and when I asked GPT-4 to write me one, it gave me an implementation that didn’t give the right result. I, uh, declined to debug further.) I was also short on time and couldn’t justify spending an entire day on this, so I needed a strategy for writing this implementation that would not require too much heavy brainpower.
Fortunately, torchvision’s CUDA implementation of roi_align is pretty simple: only 130 LOC or so:
However, it is written in the usual style of CUDA kernels: it says how to compute a single output entry on an element-by-element basis. How to convert this into a loop-free Python implementation?
Enter first class dimensions! There are a lot of ways to think about first class dimensions, but one way I like to think about it is that first class dimensions give you an easy way to write code in PyTorch that closely resembles code with explicit loops.
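To give a flavor of this before diving into roi_align, here is a minimal sketch of matrix multiplication written with first class dims (assuming the dims API that ships in functorch.dim; the comment spells out the implicit loop):

import torch
from functorch.dim import dims  # first class dims live here in recent PyTorch releases

A = torch.randn(3, 4)
B = torch.randn(4, 5)

i, j, k = dims(3)
# reads like "for each i, j: C[i, j] = sum over k of A[i, k] * B[k, j]"
C = (A[i, k] * B[k, j]).sum(k)
print(C.order(i, j).shape)  # torch.Size([3, 5]); order() lays the dims out positionally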
Here is the full roi_align implementation in Python using first class dims:
If you compare it side-by-side with the CUDA code, you’ll notice they are very similar. Here are some key differences in translation.
No need to reverse engineer per-dimension index by position. One of the first things the CUDA kernel does is extract individual indices from a linearized index into the output tensor:
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
Instead, with first class dims we simply allocate a first class dim for each index. n and c we will be able to infer automatically by using them to index; ph and pw we bind explicitly to the output size.
n, c, ph, pw = dims(4)
ph.size = pooled_height
pw.size = pooled_width
Pointer offsets turn into indexing operations. In the CUDA code, we select out a particular row of a 2D tensor by computing a pointer offset from the base pointer.
const T* offset_rois = rois + n * 5;
rois is a (K, 5) tensor, so this gets you the n’th row. With first class dims, we can write this in the natural way:
offset_rois = rois[n]
Something similar happens with input:
const T* offset_input =
    input + (roi_batch_ind * channels + c) * height * width;
becomes
offset_input = input[roi_batch_ind][c]
Basic computation transcribes as-is. The magic of first class dims is that once you’ve indexed out of a tensor with a first class dim, you can write the rest of your computation as if you’re operating element-by-element. For example, in the original code, we have:
T roi_start_w = offset_rois[1] * spatial_scale - offset;
With first class dims, this transcribes into the identical:
roi_start_w = offset_rois[1] * spatial_scale - offset
In fact, I need to vectorize this computation over every offset_rois in rois, but first class dims takes care of implicitly vmap’ing the computation, so I don’t need to do any of the book-keeping myself.
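To make the implicit vectorization concrete, the one-liner above behaves as if I had written the loop over regions myself, something along these lines (with made-up values for spatial_scale and offset, and plain positional indexing instead of first class dims):

import torch

rois = torch.randn(8, 5)           # pretend we have K = 8 regions of interest
spatial_scale, offset = 0.25, 0.5  # illustrative values only

# the explicit-loop equivalent of: roi_start_w = offset_rois[1] * spatial_scale - offset
roi_start_w = torch.stack(
    [rois[k][1] * spatial_scale - offset for k in range(rois.size(0))]
)
print(roi_start_w.shape)  # torch.Size([8])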
Control flow becomes torch.where (or clamp.) If you have an if statement whose condition varies across the vectorized elements:
if (y_low >= height - 1) {
  y_high = y_low = height - 1;
  y = (T)y_low;
} else {
  y_high = y_low + 1;
}
Just convert each internal assignment into a torch.where assignment, eliminating the explicit control flow:
y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
y_low = torch.where(y_low >= height - 1, height - 1, y_low)
y = torch.where(y_low >= height - 1, y.to(input.dtype), y)
Sometimes, if statements can be conveniently expressed in a more idiomatic way. For example:
if (y <= 0)
  y = 0;
becomes
y = y.clamp(min=0)
Handling the ROIs. Everything ported over swimmingly, except for this part:
T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
  const T y = roi_start_h + ph * bin_size_h +
      static_cast<T>(iy + .5f) * bin_size_h /
          static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
  for (int ix = 0; ix < roi_bin_grid_w; ix++) {
    const T x = roi_start_w + pw * bin_size_w +
        static_cast<T>(ix + .5f) * bin_size_w /
            static_cast<T>(roi_bin_grid_w);
    T val = bilinear_interpolate(offset_input, height, width, y, x, index);
    output_val += val;
  }
}
Hmm, that outer loop goes from 0 to roi_bin_grid_h (and the inner one to roi_bin_grid_w). The standard playbook here would be to allocate a new first class dim for each loop. But this doesn’t work:
iy, ix = dims(2)
iy.size = roi_bin_grid_h
ix.size = roi_bin_grid_w
because roi_bin_grid_h and roi_bin_grid_w are themselves tensors! In fact, roi_align internally has some non-uniform computation, where each CUDA thread may do a differing amount of work depending on the size of the ROI it was assigned. Unfortunately, first class dims doesn’t directly support this (we’d need some sort of nested tensor support), so I needed an alternate strategy.
Fortunately, I wasn’t trying to write a fast kernel, so I came up with a compromise: just do the compute over the entire input image, and then mask out the elements I didn’t care about. There were two places where I needed to mask. First, when doing bilinear interpolation, I index into the input tensor:
// do bilinear interpolation
T v1 = input[y_low * width + x_low];
T v2 = input[y_low * width + x_high];
T v3 = input[y_high * width + x_low];
T v4 = input[y_high * width + x_high];
If I don’t mask out indices that correspond to an invalid ROI, I might trigger an illegal memory access. So in the new implementation I have to mask out the elements. (Though, on further reflection, the kernel may already be working hard to avoid out of bounds accesses.)
# do bilinear interpolation, but respect the masking!
def masked_index(y, x):
    y = torch.where(ymask, y, 0)
    x = torch.where(xmask, x, 0)
    return input[y, x]

v1 = masked_index(y_low, x_low)
v2 = masked_index(y_low, x_high)
v3 = masked_index(y_high, x_low)
v4 = masked_index(y_high, x_high)
NB: here, ymask and xmask are defined as
ymask = iy < roi_bin_grid_h
xmask = ix < roi_bin_grid_w
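(For completeness, iy and ix here are ordinary first class dims. Under the compromise above I would expect them to be bound to the full input size, roughly as below; the exact sizing is my reading of “compute over the entire input image,” not a verbatim excerpt from the kernel.)

iy, ix = dims(2)
iy.size = height  # assumed: iterate over every row of the input image...
ix.size = width   # ...and every column; ymask/xmask above discard the out-of-ROI samples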
Additionally, the original CUDA code looping over iy and ix sums all the values into output_val. When I perform my summation, I must make sure to mask out all elements that wouldn’t actually have participated in the sum.
val = torch.where(ymask, val, 0)
val = torch.where(xmask, val, 0)
output = val.sum((iy, ix))
Conclusion. All in all, it took only 1.5 hours to do the port from start to finish (with a few meetings in between!), and the forwards code worked on the first try!
$ python a.py
tensor(-426.5444)
tensor(-426.5444)
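If you want to poke at a port like this yourself, a quick sanity check against torchvision’s built-in operator is easy to set up. Here is a hypothetical harness, where my_roi_align stands in for the first class dims version above and its signature is just my guess:

import torch
import torchvision

input = torch.randn(2, 3, 32, 32, requires_grad=True)
rois = torch.tensor([[0., 1., 1., 20., 20.],
                     [1., 2., 2., 24., 24.]])  # (K, 5): batch index plus box coordinates

expected = torchvision.ops.roi_align(
    input, rois, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2
)
actual = my_roi_align(  # hypothetical: the first class dims port described above
    input, rois, spatial_scale=1.0, pooled_height=7, pooled_width=7, sampling_ratio=2
)
print(expected.sum(), actual.sum())  # the two sums should agree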
First class dims are an incredibly powerful tool for writing kernels in Python, and chances are, Inductor can do a pretty good job optimizing them (in our case, we would expect Inductor to be able to fuse away all of our intermediate tensors, leaving the only inefficiency being that we operate over the entirety of the input tensor per region, rather than only the pixels in the region of interest.) The next time you need to write a kernel with persnickety index-by-index compute in Python, give it a try!
Some ideas for future work. There were three pieces of infrastructure that would have helped quite a bit here:
- Masked tensors. I manually had to take care of the masks, which is annoying when you have to mask out an indexing operation to avoid illegal memory accesses. A masked tensor abstraction that worked with first class dims could probably have eliminated this manual book-keeping.
- Nested tensors. If first class dims supported nested tensors, it would be feasible to create a first class dim whose size was a tensor (rather than a single int), indicating non-uniformity. When vectorized, this would result in a nested tensor where the variable-size ROIs were packed together into a contiguous buffer. This would eliminate the need for padding (which is effectively what we did in the kernel here), making the pointwise operations run faster. One downside, however, is that to pack the ROIs into a nested tensor, you would have to do a DtoH sync (so that you would know how big of a buffer to allocate.)
- On-device loop higher order operator. The original torchvision kernel doesn’t bother allocating the buffer, because it never materializes the nested tensor; instead, it performs the summation directly on-device with a loop with a data-dependent bound. If we had some way of expressing this concept directly in Python, we could likely generate a kernel in Triton that is just as good as the original. (This is not as useful for eager-only compute though, since you have to codegen this kernel for good performance.)