PyTorch Sparse (GNN) Compiler RFC

Some updates for this channel:

I have developed an efficient index scatter reduce operation for GPUs in CUDA and integrated it into a PyTorch extension. I plan to open-source the repository in the near future. My ultimate goal, however, is to contribute this work to Triton so that such kernels can be generated automatically, a milestone I am still working toward.

The frontend is currently defined as follows:

dst = index_scatter_reduce(dim, index, src, reduce, sorted=True) → Tensor

Unlike the existing scatter_reduce operation, we require dst and src to have the same number of dimensions, and index must be 1-D. For the specified dim argument, it is also required that index.size(0) <= src.size(dim).

For a 3-D tensor with reduce="sum", the output is calculated as follows:

dst[index[i]][j][k] += src[i][j][k]  # if dim == 0
dst[i][index[j]][k] += src[i][j][k]  # if dim == 1
dst[i][j][index[k]] += src[i][j][k]  # if dim == 2
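To make these semantics concrete, here is a loop-based reference for reduce="sum" in plain PyTorch. It is a minimal sketch, not the CUDA kernel: the helper names are hypothetical, and both assume the output size along dim is index.max() + 1, since the signature above does not say how dst is sized. The loop version is cross-checked against the built-in torch.Tensor.index_add.

```python
import torch


def index_scatter_sum_reference(dim: int, index: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based reference for index_scatter_reduce(..., reduce="sum").

    Assumption: the output size along `dim` is index.max() + 1.
    """
    assert index.dim() == 1 and index.numel() <= src.size(dim)
    out_shape = list(src.shape)
    out_shape[dim] = int(index.max()) + 1
    dst = torch.zeros(out_shape, dtype=src.dtype)
    for i, target in enumerate(index.tolist()):
        # Slice i of src along `dim` is added into slice index[i] of dst.
        dst.select(dim, target).add_(src.select(dim, i))
    return dst


def index_scatter_sum_vectorized(dim: int, index: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    """Hypothetical vectorized equivalent built on Tensor.index_add."""
    out_shape = list(src.shape)
    out_shape[dim] = int(index.max()) + 1
    # index_add needs len(index) == source.size(dim), so drop the unindexed trailing slices.
    return torch.zeros(out_shape, dtype=src.dtype).index_add(
        dim, index, src.narrow(dim, 0, index.numel())
    )


src = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
index = torch.tensor([1, 1])  # length 2 <= src.size(d) for every d
for d in range(src.dim()):
    assert torch.equal(
        index_scatter_sum_reference(d, index, src),
        index_scatter_sum_vectorized(d, index, src),
    )
```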

Additionally, we have added a sorted flag to optimize the performance of the index_scatter_reduce kernel. A sorted index improves locality and minimizes atomic operations, which substantially improves parallel efficiency on both CPUs and GPUs. The flag can be set from the revised frontend, as shown in the PyG framework code (TODO). A sorted index is in non-decreasing order; for instance, index = [0, 0, 0, 1, 1, 2] is sorted and can be passed with sorted=True.
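The sketch below illustrates why a sorted index helps. With a non-decreasing index, every dst row receives a contiguous slice of src, so the slice boundaries can be found with torch.searchsorted and each row can be reduced independently, with no two rows writing to the same output. The helper names are hypothetical, num_segments (the number of dst rows) is assumed to be known, and this is not the actual scheduling used by the kernel.

```python
import torch


def is_sorted_index(index: torch.Tensor) -> bool:
    # Non-decreasing order, e.g. [0, 0, 0, 1, 1, 2]
    return bool((index[1:] >= index[:-1]).all())


def segment_offsets(index: torch.Tensor, num_segments: int) -> torch.Tensor:
    # For a sorted index, src rows offsets[r]:offsets[r + 1] all map to dst row r.
    return torch.searchsorted(index, torch.arange(num_segments + 1))


def sorted_index_sum(index: torch.Tensor, src: torch.Tensor, num_segments: int) -> torch.Tensor:
    assert is_sorted_index(index)
    offsets = segment_offsets(index, num_segments).tolist()
    # Each dst row reduces its own contiguous slice of src, so rows never
    # contend for the same output and no atomic updates are needed.
    return torch.stack(
        [src[offsets[r]:offsets[r + 1]].sum(dim=0) for r in range(num_segments)]
    )


index = torch.tensor([0, 0, 0, 1, 1, 2])
src = torch.arange(12, dtype=torch.float32).reshape(6, 2)
out = sorted_index_sum(index, src, 3)
assert torch.equal(out, torch.zeros(3, 2).index_add(0, index, src))
```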

I am also glad to share some results for our kernel, tested with a carry (feature) dimension of 128 on an H100 GPU.

Some thoughts on the front end:

I tried JAX and found that it handles the carry dimensions well: you specify which dimensions are updated and which are scattered. It can also be optimized in some cases by setting the indices_are_sorted option (Ref).

Here is a code snippet using jax.lax.scatter_add:

import jax.numpy as jnp
import jax.lax

# 2-D target array to be updated (4 rows, 3 columns)
target = jnp.zeros((4, 3))

# Scatter indices: one row index per update (the last axis holds the index vector)
# Here we update the first and third rows
indices = jnp.array([[0], [2]])

# Updates to add: each row corresponds to a row listed in indices
updates = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)

# Define dimension numbers
dimension_numbers = jax.lax.ScatterDimensionNumbers(
    update_window_dims=(1,),  # Axis 1 of updates is the window (a full row of columns)
    inserted_window_dims=(0,),  # Operand axis 0 (rows) has window size 1 and is absent from updates
    scatter_dims_to_operand_dims=(0,)  # Each index value addresses operand axis 0
)

# Perform scatter add operation
result = jax.lax.scatter_add(target, indices, updates, dimension_numbers, indices_are_sorted=True)

print("Result:\n", result)
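For the 1-D-index case specifically, JAX's indexed-update syntax expresses the same scatter more compactly and also accepts indices_are_sorted. A minimal sketch: the arrays mirror the snippet above, and the column case shows the 1-D index applied to axis 1 while axis 0 is carried along, which corresponds to dim == 1 in index_scatter_reduce.

```python
import jax.numpy as jnp

target = jnp.zeros((4, 3))
updates = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
row_index = jnp.array([0, 2])  # 1-D and non-decreasing

# dim == 0: add each update row into the row named by row_index
rows = target.at[row_index].add(updates, indices_are_sorted=True)

# dim == 1: the same 1-D index now addresses columns; every row is carried along
cols = jnp.zeros((3, 4)).at[:, row_index].add(updates.T, indices_are_sorted=True)
```

Repeated entries in row_index accumulate, matching the reduce="sum" semantics.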

@jansel Do you have any suggestions on the frontend design? How can we make it general yet efficient?

The simplest option would be to add a new operation to PyTorch. If it would be useful to others, you can upstream your op to PyTorch itself. If you define a device="meta" version of your op (to compute output sizes), it should work with torch.compile. cc @zou3519 on custom ops.

Are there any generalizations of this op you are thinking of? Other ops that would follow a similar pattern?

To be honest, I consider index scatter reduce to be quite general and applicable across many domains, such as graph analysis and point-cloud voxel processing. I would prefer not to introduce a new operation, especially since I believe the core issue lies in PyTorch’s definition of scatter_reduce.

PyTorch defines the scatter_reduce operation as follows:

self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2

However, this definition creates a mismatch at the frontend. In GNNs, for instance, what we actually want to compute is:

self[index[i]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[j]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[k]] += src[i][j][k]  # if dim == 2

Therefore, to perform GNN aggregation the way PyTorch prescribes, one must first broadcast the index tensor to the shape of src and then apply scatter_reduce. This limits TorchInductor’s optimization potential, because once the index tensor has been broadcast, the fact that it is sorted can no longer be recognized. In contrast, JAX accommodates this pattern more naturally thanks to its more general frontend: both index scatter_reduce (1-D index) and scatter_reduce (multi-dimensional index) can be expressed concisely with jax.lax.scatter and its reducing variants, such as jax.lax.scatter_add.
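The broadcast workaround looks roughly like this in plain PyTorch. This is a sketch with a hypothetical broadcast_index helper, not PyG's actual code. Once the index has been expanded to src's shape, scatter_reduce only sees a full-size index tensor:

```python
import torch


def broadcast_index(index: torch.Tensor, src: torch.Tensor, dim: int) -> torch.Tensor:
    # Reshape the 1-D index to size -1 at `dim` and 1 elsewhere, then expand to src's shape.
    shape = [1] * src.dim()
    shape[dim] = -1
    return index.view(shape).expand_as(src)


src = torch.arange(12, dtype=torch.float32).reshape(6, 2)  # 6 messages, 2 features
index = torch.tensor([0, 0, 0, 1, 1, 2])                    # destination of each message

# PyTorch's scatter_reduce requires an index with the same number of dims as src.
via_scatter = torch.zeros(3, 2).scatter_reduce(
    0, broadcast_index(index, src, 0), src, reduce="sum"
)
# tensor([[ 6.,  9.], [14., 16.], [10., 11.]])
```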

However, changing this definition may break backward compatibility, since scatter_reduce has had these semantics in all previous PyTorch versions.

Can’t the broadcast be achieved with zero extra kernels by having index.stride(n) == 0 in the unused dimensions?
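A quick way to check this: expand() produces a stride-0 view without copying, so the broadcast index shares storage with the 1-D index. The recover_1d_index helper below is hypothetical; it sketches how a backend could detect such a view, get back the 1-D index, and then, for example, test whether it is sorted.

```python
import torch

src = torch.randn(6, 128)                 # 6 edges, carry (feature) dim 128
index = torch.tensor([0, 0, 0, 1, 1, 2])

expanded = index.view(-1, 1).expand_as(src)
# expanded.stride() == (1, 0): the feature axis is broadcast, no copy is made,
# and expanded.data_ptr() == index.data_ptr() (same storage).


def recover_1d_index(idx: torch.Tensor, dim: int):
    """Hypothetical helper: return the 1-D index if every axis other than `dim` is broadcast."""
    broadcast = all(
        idx.stride(d) == 0 or idx.size(d) == 1 for d in range(idx.dim()) if d != dim
    )
    if not broadcast:
        return None
    # Take position 0 on every other axis, keeping the full `dim` axis.
    return idx[tuple(slice(None) if d == dim else 0 for d in range(idx.dim()))]


recovered = recover_1d_index(expanded, 0)
```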

I think you are looking for torch.Tensor.index_add_ (PyTorch 2.2 documentation).
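For reference, a minimal GNN-style aggregation with the existing 1-D-index APIs. index_add_ covers reduce="sum", and torch.Tensor.index_reduce_ (a beta API in recent PyTorch releases) accepts "prod", "mean", "amax", and "amin". The toy messages and destination indices are made up for illustration.

```python
import torch

src = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # one message per edge
index = torch.tensor([0, 0, 1])                             # destination node of each edge

summed = torch.zeros(2, 2).index_add_(0, index, src)
# include_self=False leaves the initial zeros out of the reduction.
maxed = torch.zeros(2, 2).index_reduce_(0, index, src, "amax", include_self=False)
mean = torch.zeros(2, 2).index_reduce_(0, index, src, "mean", include_self=False)
```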

More generally, in PyTorch we already have a (non-public lol) segment_reduce under torch._segment_reduce. We could:

  • Make this function public, write proper docs, clean up its API if needed, etc.
  • Have codegen for this function. This can already be done in Triton and in our IR, as our IR supports reductions with multiple inputs.

Would this give you enough support to implement your models?


Thank you for the information! It’s exactly what I’ve been looking for. I just reached out to the PyG team to ask why they broadcast the index and then call scatter_reduce instead of using index_reduce (PyG’s code: Ref).

I had previously considered segment_reduce, but I no longer view it as critically important, because adding a sorted-index flag to index_reduce effectively makes it equivalent to segment_reduce, but with greater versatility.
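The equivalence can be sketched as follows: a sorted index and a vector of per-segment lengths carry the same information, and converting between them is cheap. The sketch uses only public ops (torch.bincount, torch.repeat_interleave, torch.split) rather than torch._segment_reduce, whose signature is not discussed here. An unsorted index cannot be described by lengths alone, which is where index_reduce is more versatile.

```python
import torch

index = torch.tensor([0, 0, 0, 1, 1, 2])  # sorted index (index_reduce view)
src = torch.arange(12, dtype=torch.float32).reshape(6, 2)
num_segments = 3

# Sorted index -> per-segment lengths (segment_reduce view)
lengths = torch.bincount(index, minlength=num_segments)
# Per-segment lengths -> sorted index
assert torch.equal(torch.repeat_interleave(torch.arange(num_segments), lengths), index)

# Segment-wise sum driven by lengths vs. index_add driven by the sorted index
seg_sum = torch.stack([seg.sum(dim=0) for seg in torch.split(src, lengths.tolist())])
idx_sum = torch.zeros(num_segments, 2).index_add(0, index, src)
```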

I’m excited about the possibility of code generation for index_reduce with a sorted index! However, I’m intrigued by the following statement:

In my experience, Triton does not yet support a sorted scan, and the torch IR does not accommodate such a reduction pattern. Could you clarify this and perhaps share more code references (e.g., links to the relevant code in the PyTorch repo)? Thank you!

What I meant is that Triton cannot sort, but if the input is already sorted, Triton (and Inductor’s IR) support reductions with multiple inputs and outputs. In particular, we could codegen the reduction that Jason described above, similar to how we codegen Welford’s algorithm for var_mean.
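To illustrate the kind of multi-input reduction meant here, the sketch below (pure Python, hypothetical names, not Inductor’s actual codegen) carries a (key, value) pair per element, much as Welford’s algorithm carries several running values. The combine function restarts the sum whenever the key changes. It is associative only when equal keys are contiguous, which is exactly what a sorted index guarantees. A scan with it yields running per-segment sums, and the last entry of each run is that segment’s total.

```python
from functools import reduce
from itertools import accumulate


def combine(left, right):
    """Combine two (key, value) partial results of a segmented sum (hypothetical)."""
    (k_left, v_left), (k_right, v_right) = left, right
    return (k_right, v_left + v_right) if k_left == k_right else (k_right, v_right)


keys = [0, 0, 0, 1, 1, 2]  # sorted index
vals = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
pairs = list(zip(keys, vals))

# Inclusive scan: a running sum that restarts at every new key
running = list(accumulate(pairs, combine))
# The last entry of each run holds that segment's total (later entries overwrite earlier ones)
totals = {k: v for k, v in running}

# With sorted keys, regrouping does not change the result (associativity in action)
full = reduce(combine, pairs)
assert all(
    combine(reduce(combine, pairs[:m]), reduce(combine, pairs[m:])) == full
    for m in range(1, len(pairs))
)

# With unsorted keys (key 0 not contiguous), regrouping changes the result
a, b, c = (0, 1.0), (1, 2.0), (0, 3.0)
lhs, rhs = combine(combine(a, b), c), combine(a, combine(b, c))  # (0, 3.0) vs (0, 4.0)
```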

Even cooler: once peterbell10’s PR “WIP: Add higher order associative scan operator” (pytorch/pytorch #119430 on GitHub) lands, and we also implement it for reductions, you should be able to do this yourself without having to touch Inductor’s IR.
