Some updates for this channel:
I have developed an efficient index scatter reduce operation for GPUs using CUDA, and have integrated it into a PyTorch extension. I plan to open-source this repository in the near future. However, my ultimate goal is to contribute this work to Triton to enable automatic code generation, a milestone that still challenges me.
The front end currently looks like this:
dst = index_scatter_reduce(dim, index, src, reduce, sorted=True) → Tensor
Unlike the traditional scatter_reduce operation, we require dst and src to have the same number of dimensions, and index must be one-dimensional. For the specified dim argument, index.size(0) <= src.size(dim) must also hold.
For a 3-D tensor with reduce="sum", the output is calculated as follows:
dst[index[i]][j][k] += src[i][j][k] # if dim == 0
dst[i][index[j]][k] += src[i][j][k] # if dim == 1
dst[i][j][index[k]] += src[i][j][k] # if dim == 2
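To make the three cases above concrete, here is a minimal loop-based reference in plain Python on nested lists. It is only a sketch of the semantics, not the CUDA kernel. The function name index_scatter_sum_reference and the dst_size parameter are hypothetical: the post does not say how the size of dst along dim is chosen, so the sketch takes it as an explicit argument. It also assumes that only the first len(index) slices of src along dim are scattered when index is shorter than src.

```python
def index_scatter_sum_reference(dim, index, src, dst_size):
    """Loop-based reference for reduce="sum" on a 3-D nested list.

    dst_size (hypothetical parameter) is the extent of dst along `dim`.
    """
    shape = [len(src), len(src[0]), len(src[0][0])]
    out_shape = list(shape)
    out_shape[dim] = dst_size
    dst = [[[0.0] * out_shape[2] for _ in range(out_shape[1])]
           for _ in range(out_shape[0])]

    # Assumption: only the first len(index) slices along `dim` are scattered.
    limits = list(shape)
    limits[dim] = len(index)

    for i in range(limits[0]):
        for j in range(limits[1]):
            for k in range(limits[2]):
                pos = [i, j, k]
                # Redirect the coordinate along `dim` through the index.
                pos[dim] = index[pos[dim]]
                dst[pos[0]][pos[1]][pos[2]] += src[i][j][k]
    return dst


# dim == 0: rows 0 and 1 of src both land in row 0 of dst.
src0 = [[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]]  # shape (3, 1, 2)
dst0 = index_scatter_sum_reference(0, [0, 0, 1], src0, dst_size=2)

# dim == 2: the last axis is redirected instead.
src2 = [[[1.0, 2.0, 3.0]]]  # shape (1, 1, 3)
dst2 = index_scatter_sum_reference(2, [1, 1, 0], src2, dst_size=2)
```

A reference like this is mainly useful for checking a fast kernel against small inputs.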
Additionally, we have added a sorted flag to optimize the performance of the index_scatter_reduce kernel. A sorted index improves locality and reduces the number of atomic operations, which substantially improves parallel efficiency on both CPUs and GPUs. This feature can be enabled from the revised front end, as shown in the PyG framework code (TODO). A sorted index is in non-decreasing order. For instance, index = [0, 0, 0, 1, 1, 2] qualifies as a sorted index when sorted=True.
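A rough illustration of why a sorted index helps, as a CPU-side sketch in plain Python rather than the actual kernel: when equal index values are contiguous, each run can be accumulated locally and written to its destination row exactly once, so no two writers ever touch the same row. The function name segment_sum_sorted and the dst_size parameter are hypothetical, and the sketch covers only a 2-D src scattered along dim 0.

```python
def segment_sum_sorted(index, src, dst_size):
    """Sum rows of src that share an index value; index must be non-decreasing."""
    assert all(a <= b for a, b in zip(index, index[1:])), "index is not sorted"
    width = len(src[0])
    dst = [[0.0] * width for _ in range(dst_size)]

    start = 0
    while start < len(index):
        # Find the end of the run of equal index values.
        end = start
        while end < len(index) and index[end] == index[start]:
            end += 1
        # Accumulate the run locally, then write the destination row once.
        acc = [0.0] * width
        for row in src[start:end]:
            for c, value in enumerate(row):
                acc[c] += value
        dst[index[start]] = acc
        start = end
    return dst


index = [0, 0, 0, 1, 1, 2]
src = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
out = segment_sum_sorted(index, src, dst_size=3)
```

On a GPU the same idea would let each run be reduced by one thread or warp, with atomics needed only where a run is split across workers.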
I am also glad to share some results from our kernel, tested with carry feature=128 on an H100 GPU.
Some thoughts on the front end:
I tried JAX and found that it handles the carry option well by specifying the update and scatter dimensions. It can also be optimized in some cases by setting the indices_are_sorted option (Ref).
Here is a code snippet using jax.lax.scatter_add:
import jax.numpy as jnp
import jax.lax
# High-dimensional target array to be updated
target = jnp.zeros((4, 3))
# High-dimensional indices where updates will be added
# Suppose we want to update the first and third row
indices = jnp.array([[0], [2]])
# High-dimensional updates to add
# Each update corresponds to a row mentioned in indices
updates = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
# Define dimension numbers
dimension_numbers = jax.lax.ScatterDimensionNumbers(
    update_window_dims=(1,),  # Axis 1 of updates holds the window (one full row of target)
    inserted_window_dims=(0,),  # Target axis 0 has window size 1 and is omitted from updates
    scatter_dims_to_operand_dims=(0,)  # Each index selects a position along target axis 0
)
# Perform scatter add operation
result = jax.lax.scatter_add(target, indices, updates, dimension_numbers, indices_are_sorted=True)
print("Result:\n", result)
@jansel Do you have any suggestions on the front-end design? How can we make this general yet efficient?