How does advanced indexing work when I combine a tensor and a single index


Without looking, what does this program do?

import torch as t
x = t.tensor([[1,2],[3,4]])
m = t.tensor([True,True])
print(x[m][0])
print(x[m,0])

Answer below the fold.

I have always found it difficult to understand what exactly advanced indexing does in more complicated cases. But recently, while trying to understand a bug in our advanced indexing implementation (pytorch/pytorch issue #71673, “Fancy indexing bug when combining masks with indexes”), I came up with a new way of explaining how it all works, and I want to share it with you!

First, it’s helpful to recap how advanced indexing works in simple cases. Advanced indexing refers to what happens when you index a tensor with another tensor (rather than just a number or slice). What happens depends on the kind of tensor you pass as the index:

  • If you pass a boolean tensor, it acts as a mask: you include rows where the mask is True and drop rows where it is False. For a 1D mask, the number of dimensions doesn’t change, but the size of the dimension you are masking may go down (because you filtered rows out).
  • If you pass an integer tensor, it acts as indices: you include the rows named by the numeric indices (repeats allowed). Once again, the number of dimensions doesn’t change for a 1D index (in fact, it can increase if you pass a 2D or higher-dimensional index tensor), but the size of the dimension you are indexing may go up or down.
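A quick illustration of both rules, as a small sketch (the tensor x and the index values here are made up for the example):

```python
import torch

x = torch.tensor([[1, 2], [3, 4], [5, 6]])

# Boolean mask: keep rows where the mask is True; still 2-D, but fewer rows.
mask = torch.tensor([True, False, True])
masked = x[mask]            # rows 0 and 2 -> [[1, 2], [5, 6]], shape (2, 2)

# 1-D integer index: pick rows by position; repeats are allowed, so rows can grow.
idx = torch.tensor([2, 0, 0, 1])
picked = x[idx]             # shape (4, 2): rows 2, 0, 0, 1

# 2-D integer index: the index's shape replaces dim 0, adding a dimension.
idx2 = torch.tensor([[0, 1], [2, 2]])
grid = x[idx2]              # shape (2, 2, 2)
```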

This is enough to explain the first print:

>>> x[m]  # m says to include everything, so we do!
tensor([[1, 2],
        [3, 4]])
>>> x[m][0]
tensor([1, 2])

But what about the second print?

>>> x[m,0]
tensor([1, 3])

I had previously thought that x[i,j] == x[i][j] (i.e., that I could think of tupled indices as just applying the indices one by one), but this clearly is not the case here. NumPy’s advanced indexing documentation (NumPy manual, “Indexing on ndarrays”) says of the tupled case: “Advanced indices always are broadcast and iterated as one.” What the heck does this mean? Essentially, the 0 here doesn’t mean “take the 0th row of x[m]”; instead, it means “for each row selected by the boolean mask, take its 0th column.” We should think of each index in the tuple as referring to a separate dimension of the tensor: m filters the first dimension (rows), while 0 selects on the second dimension (columns). It is compositional!
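To make “broadcast and iterated as one” concrete, here is a small sketch; r, c, and the manual loop are illustrative names, not anything from PyTorch. The first part spells out what x[m, 0] means, and the rest contrasts paired indexing x[r, c] with sequential indexing x[r][c]:

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])
m = torch.tensor([True, True])

# x[m, 0] spelled out: for each row kept by m, take that row's column 0.
rows = m.nonzero().squeeze(1).tolist()      # row positions kept by the mask: [0, 1]
manual = torch.stack([x[r, 0] for r in rows])
assert torch.equal(manual, x[m, 0])         # both are [1, 3]

# Two index tensors are iterated as one: result[k] = x[r[k], c[k]].
r = torch.tensor([0, 1, 1])
c = torch.tensor([1, 0, 1])
paired = x[r, c]        # [x[0, 1], x[1, 0], x[1, 1]] = [2, 3, 4]
sequential = x[r][c]    # x[r] is (3, 2); then rows c of *that*: another (3, 2) tensor

# They are also broadcast: a (2, 1) row index and a (2,) column index give (2, 2).
grid = x[torch.tensor([[0], [1]]), torch.tensor([1, 0])]   # [[2, 1], [4, 3]]
```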

Is there a way to decompose x[m,0] into simpler operations? Yes! In fact, this is how advanced indexing is implemented internally. The algorithm goes like this:

  1. First, do the view operations (integer and slice indices).
  2. Then, do the advanced indexing (with tensors) all in one go.

So all we need to do is reorder the operations: select column 0 first (a view), and then do the indexing. Fortunately, indexing syntax has us covered: x[:,0][m] is equivalent. In general, the algorithm goes like this:
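A quick check of the reordering, as a sketch assuming current PyTorch semantics: x[:, 0] is an ordinary strided view of x, and masking that view gives the same result as x[m, 0].

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])
m = torch.tensor([True, True])

col0 = x[:, 0]                            # basic indexing: a view, no copy
assert col0.data_ptr() == x.data_ptr()    # column 0 starts at x's first element
assert col0.stride() == (2,)              # step over a whole row to reach the next entry
reordered = col0[m]                       # now the advanced indexing, on the view
assert torch.equal(reordered, x[m, 0])    # both are [1, 3]
```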

  1. Figure out which dimension each view operation corresponds to, and do a single mega-view operation directly on those dimensions (without indexing). E.g., if t1 and t2 are 1D tensors, x[t1, i1, t2, i2] becomes y = x[:,i1,:,i2]. The amazing thing about strides is that it is always possible to express this as a view.
  2. Do a single mega-indexing operation with all of the index tensors, e.g., y[t1, t2].
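The two-step algorithm above can be sketched as a Python helper. decompose_index is a hypothetical name, and this is not PyTorch’s actual implementation (which lives in C++); it assumes every index entry is an int, a slice, or a 1D index tensor, so each entry consumes exactly one dimension:

```python
import torch


def decompose_index(x, index):
    """Hypothetical helper: apply a tuple index as one view, then one advanced index.

    Sketch only: assumes each entry is an int, a slice, or a 1-D index tensor
    (bool or long), so every entry consumes exactly one dimension of x.
    """
    view_index = []  # step 1: ints and slices as given; tensors replaced by ':'
    adv_index = []   # step 2: the tensors, plus ':' for dims a slice kept
    for entry in index:
        if isinstance(entry, torch.Tensor):
            if entry.dim() != 1:
                raise NotImplementedError("sketch only handles 1-D index tensors")
            view_index.append(slice(None))
            adv_index.append(entry)
        elif isinstance(entry, int):
            view_index.append(entry)       # the view removes this dim entirely
        elif isinstance(entry, slice):
            view_index.append(entry)       # the view narrows this dim...
            adv_index.append(slice(None))  # ...and step 2 leaves it alone
        else:
            raise TypeError(f"unsupported index entry: {entry!r}")
    y = x[tuple(view_index)]     # mega-view: basic indexing, shares x's storage
    return y[tuple(adv_index)]   # mega-index: all tensors applied together


x = torch.tensor([[1, 2], [3, 4]])
m = torch.tensor([True, True])
print(decompose_index(x, (m, 0)))   # tensor([1, 3]), same as x[m, 0]
```

Because step 1 only produces a strided view, no data is copied until the single indexing call in step 2. Supporting a boolean mask with more than one dimension would mean reserving one `:` per mask dimension in the view step, which is exactly the kind of offset bookkeeping that is easy to get wrong.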

Compositionality restored!

P.S. The bug in the issue is that when an advanced indexing tensor is not 1D, we don’t offset the view operation indices appropriately. Subtle.
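To see why the offset matters, here is a minimal hand-worked sketch (not PyTorch’s code), assuming NumPy’s rule that a k-dimensional boolean mask consumes k dimensions: with a 2D mask in front, the integer index that follows applies to dimension 2, not dimension 1.

```python
import torch

x = torch.arange(12).reshape(2, 2, 3)
mask = torch.tensor([[True, False], [True, True]])  # 2-D mask: covers dims 0 and 1

# Intended meaning of x[mask, 0]: for each (i, j) where mask is True, take x[i, j, 0].
expected = torch.stack([x[i, j, 0] for i, j in mask.nonzero().tolist()])

# Correct offset: the mask consumes mask.dim() == 2 dims, so the int applies to dim 2.
y = x.select(mask.dim(), 0)   # a view, same as x[:, :, 0]; shape (2, 2)
result = y[mask]              # one advanced-indexing step with the full mask
assert torch.equal(result, expected)  # both are [0, 6, 9]

# Wrong offset: treating the mask as if it used one dim would select on dim 1.
# x.select(1, 0) has shape (2, 3), which no longer lines up with the (2, 2) mask.
assert x.select(1, 0).shape != mask.shape
```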

@rgommers points out NEP 21 — Simplified and explicit advanced indexing — NumPy Enhancement Proposals as a really good resource on advanced indexing.