PyTorch Sparse(GNN) Compiler RFC

:rocket: The feature, motivation and pitch

torch.compile has been included since PyTorch 2.0 and has proven to be a powerful way to greatly speed up PyTorch code. The PyTorch Inductor has pioneered graph-mode execution, significantly accelerating a spectrum of machine learning applications, including CNNs, Transformers, and Graph Neural Networks (GNNs). Despite these advancements, end-to-end optimization for PyG-based GNNs remains suboptimal. The root of this inefficiency lies in the current implementation of the fused scatter_add kernel, e.g. fused_index_new_zeros_scatter_add_0. Profiling results show that this fused scatter_add kernel accounts for over 80% of the running time of typical GNN models. The code provided here shows that the codegen of the fused gather-scatter lowers to atomic operations, which significantly hurts performance. This proposal sets forth a strategy to refine this mechanism, focusing initially on CPU backends because of their simpler thread hierarchy compared to GPUs and the current limitations of the Triton backend’s warp-level specification. Check here for the Triton limitation on sparse kernel codegen.
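For concreteness, the pattern can be reproduced with native ops (a minimal sketch standing in for PyG's scatter helper; the name gather_scatter_add is illustrative only):

import torch

# Gather + scatter_add pattern as used in PyG message passing (illustrative stand-in).
def gather_scatter_add(x, edge_index):
    row, col = edge_index
    x_j = x[row]                                    # gather source features
    index = col.unsqueeze(-1).expand(-1, x.size(1))
    return x.new_zeros(x.size(0), x.size(1)).scatter_add_(0, index, x_j)

x = torch.randn(10_000, 32)
edge_index = torch.randint(10_000, (2, 200_000))
compiled = torch.compile(gather_scatter_add)
out = compiled(x, edge_index)  # inspect the generated kernel, e.g. with TORCH_LOGS="output_code"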

Alternatives

Option 1

The first option is to develop an independent compiler pass specifically for sparse tensors. However, this approach is challenging for two major reasons. First, the backend of the Inductor scheduler is mainly designed for pointwise kernel fusion. Second, the input for a sparse tensor can vary significantly, both in terms of the programming interface and the format used. For instance, PyTorch Geometric (PyG) employs its own SparseTensor input, which differs from the native sparse tensor format used by PyTorch, as detailed in the official PyTorch documentation.
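To illustrate the interface mismatch, here is a rough comparison of the two formats (assumes torch_sparse is installed; the tensors are illustrative only):

import torch
from torch_sparse import SparseTensor  # PyG's own sparse format

row = torch.tensor([0, 1, 1, 2])
col = torch.tensor([1, 0, 2, 1])
value = torch.ones(4)

adj_pyg = SparseTensor(row=row, col=col, value=value, sparse_sizes=(3, 3))  # PyG-style input
adj_native = torch.sparse_coo_tensor(torch.stack([row, col]), value, (3, 3)).to_sparse_csr()  # native CSR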

Option 2

Another alternative is to register the sparse tensor and the sparse SpMM operator within torch.ops. The Inductor would then default to calling an external kernel for SpMM-related operations (similar to segment_reduce). This approach has its downsides too, as it may diminish opportunities to fuse kernels around SpMM-like or scatter_reduce operations.
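A rough sketch of that extern-kernel route, assuming a hypothetical sparse_gnn::spmm_csr op registered via torch.library (the namespace, op name, and implementation are placeholders, not an existing API surface):

import torch

lib = torch.library.Library("sparse_gnn", "DEF")  # hypothetical namespace
lib.define("spmm_csr(Tensor crow, Tensor col, Tensor values, Tensor dense) -> Tensor")

def spmm_csr_cpu(crow, col, values, dense):
    # Reference CPU implementation; a tuned kernel would go here.
    a = torch.sparse_csr_tensor(crow, col, values, size=(crow.numel() - 1, dense.size(0)))
    return a @ dense

lib.impl("spmm_csr", spmm_csr_cpu, "CPU")
# Inductor would then treat torch.ops.sparse_gnn.spmm_csr as an extern/fallback kernel,
# at the cost of losing fusion opportunities around it.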

Additional context

Objective:

The primary goal is to enhance the PyTorch Inductor’s capability for optimizing scatter-based GNN models end to end, specifically targeting CPU backends. The ideal codegen result will be similar to the CSR part of the aten::scatter_reduce kernel.

Expected Performance:
These enhancements are expected to bridge the gap between the current and optimal performance for GNNs, thereby unlocking new efficiencies and potential within PyTorch’s framework. Technically, this project targets:

  • Codegen results will achieve the kernel-wise performance proposed by this code, which showed speedups of 4-7x.
  • End-to-end results are expected to reach 2-3x, based on the breakdown profiling of mainstream GNN models.

Techniques:

We will be working on the following new ideas:

  • Pattern Matching for Scatter-Based Operations:
    Investigation reveals that kernels involving scatter operations devolve into atomic-based IRs, which leads to subpar code generation. A pattern-matching strategy will be developed to identify and address these inefficiencies. Note that this pattern match happens in inductor/scheduler.py. We will add a sparse scheduler pass after the fusion process, filtering the scheduler nodes that need to be optimized. Below is possible code for reference.
self.create_foreach_nodes()
self.topological_sort_schedule()
self.logged_slow_fusion = set()
self.fuse_nodes()
self.sparse_aware_nodes() # sparse specific pass
  • Preprocessing for Sparse Formats:
    Recognizing the unique requirements of GNNs, we will implement a detection mechanism for input data indicative of sparsity. Such data will be preprocessed and converted into Compressed Sparse Row (CSR) format, making it more amenable to optimization for sparse operations. A possible conversion sketch is shown after this list.
  • IR Transformation:
    We propose to transform the conventional schedule nodes into new, sparse-specific schedule nodes within the PyTorch inductor’s IR. This will facilitate the generation of more efficient execution plans for sparse data. The code below shows a possible IR using inductor’s primitive:
buf1.group.device = cpu
buf1.group.iteration = ((nodes, feats), ())
buf1.sizes = ([nodes, feats], [])
buf1.mutations = ['buf0']
class buf1_loop_body:
    var_ranges = {z0: nodes, z1: feats}
    index0 = z0
    index1 = z0 + 1
    index2 = feats*indirect1 + z1
    index3 = feats*z0 + z1
    def body(self, ops):
        start_index = self.get_index('index0')
        start = ops.load('csr_ptr', start_index)
        end_index = self.get_index('index1')
        end = ops.load('csr_ptr', end_index)
        out_idx = self.get_index('index3')
        acc = ops.load('C', out_idx) # acc = 0
        for i in range(start, end):
            col = ops.load('csr_ind', i)             # column index of the i-th nonzero
            set_indirect1 = self.set_indirect1(col)  # feeds index2 = feats*indirect1 + z1
            B_index = self.get_index('index2')
            B_val = ops.load('B', B_index)
            acc = ops.add(acc, B_val)
        store = ops.store('buf1', out_idx, acc, None)
        return store

  • Sparse-Aware Code Generation:
    Leveraging the modified IR, we will develop a code generation engine capable of producing highly optimized code for executing sparse operations in CSR format on CPU backends. The ideal final code can be checked here.
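As referenced in the preprocessing item above, here is a minimal sketch of the intended edge_index-to-CSR conversion using only stock torch ops (function name and details are illustrative, not the final implementation):

import torch

def edge_index_to_csr(edge_index, num_nodes):
    # Sort edges by the destination (reduction) axis and build CSR row pointers.
    row, col = edge_index
    perm = torch.argsort(col)
    row, col = row[perm], col[perm]
    crow = torch.zeros(num_nodes + 1, dtype=torch.long)
    crow[1:] = torch.bincount(col, minlength=num_nodes).cumsum(0)
    return crow, row  # crow: CSR pointers over destinations; row: source node per nonzero

num_nodes, num_edges = 10_000, 200_000
edge_index = torch.randint(num_nodes, (2, num_edges))
crow, csr_col = edge_index_to_csr(edge_index, num_nodes)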

cc @mingfeima @yanbing-j @xinchen9 @zhang677


Thank you for the suggestion. We would definitely be open to contributions here to improve inductor’s codegen. I think better support for sparsity would be awesome.

Do you think this speedup could be achieved through more generic changes to how scatter_add is implemented in inductor? (There are a number of different algorithms that could avoid the need for atomic_add.) Or is this a case where we need to use domain specific knowledge about the SparseTensor representation? I’d lean towards doing things more generically, if we can do it in a way that gets good performance.

@davidberard98 and @aakhundov have made some progress mapping some jagged tensor sparse representations to inductor, so their learnings might be useful in this case.

@zou3519 is looking at better custom kernel support in inductor, if you decide to go the extern kernel route.

What do you need from us to unblock you? Do you feel like you have enough understanding to prototype some of these things? Where do you need help?


Thanks for the suggestions! Curious about this part:

Recognizing the unique requirements of GNNs, we will implement a detection mechanism for input data indicative of sparsity.

Do you mean that the input data will have specific non-torch.Tensor format that will be easy to detect (e.g., PyG’s or PT’s sparse tensor formats), assuming that it’s made its way to the TorchInductor? Or will the sparse data rather be represented as vanilla torch.Tensors, with the recognition based on some (domain-specific) semantics of how the data is handled?

Thanks for your reply! I am thrilled to know PyTorch is interested in promoting a compiler for sparse workloads. Sorry for my late reply. First of all, let me clarify my starting point:

It all starts from the kernel fusion of the GAS (Gather-Apply-Scatter) model, which is fundamental in graph processing. To be specific, here is the code:


# Basic "Gather-Apply-Scatter" patterns commonly used in PyG:
def gather_scatter(x, edge_index, reduce="sum"):
    row, col = edge_index
    x_j = x[row]
    return scatter(x_j, col, dim_size=x.size(0), reduce=reduce)

The issue that blocks me is that while the GAS model can be fused by the Torch Inductor, it currently reduces to an atomic_add operation. Although various algorithms exist to circumvent atomic_add, in this context, eliminating atomic operations seems feasible only through data transformation, particularly because the edge_index might be unsorted. For example, the edge_index can be created as

    num_nodes, num_edges = 10_000, 200_000
    edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device)

In this code, the edge_index indicates that several nodes will be reduced into a single node. However, since we don’t know in advance which nodes map to which output, the only option is atomic_add.

If we do the data transformation at the beginning, for example, by sorting the edge_index up front, we can then transform the GAS model into gather_segment_reduce (I think this could easily be handled by PyG’s frontend, with a code script similar to the principle here):

def gather_segment_reduce(x, edge_index, reduce="sum"):
    row, col = edge_index # sorted with col
    x_j = x[row]
    return segment_reduce(x_j, col, reduce=reduce)

In this way, since we have sorted the second dimension of edge_index, we can reduce atomic operations to the largest extent. This is similar to the code provided by torch-scatter’s segment_coo function.
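For reference, a usage sketch of torch-scatter’s segment_coo on the same shapes (assumes torch_scatter is installed; the data is random and only illustrative):

import torch
from torch_scatter import segment_coo

num_nodes, num_edges = 10_000, 200_000
x = torch.randn(num_nodes, 32)
edge_index = torch.randint(num_nodes, (2, num_edges))

col, perm = torch.sort(edge_index[1])  # segment_coo requires a sorted index
row = edge_index[0][perm]

out = segment_coo(x[row], col, dim_size=num_nodes, reduce="sum")  # shape [num_nodes, 32]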

However, we found several challenges for the compiler under this scenario:

  1. PyTorch’s torch.segment_reduce function is designed to work with 1D tensors. It performs segment-wise reduction operations based on a corresponding 1D indices tensor. If x_j has more than one dimension, e.g. a 2-D tensor, this function will fail.
  2. segment_reduce can’t be fused with gather/index operator, i.e. x_j = x[row].
  3. segment_coo proposed by PyG isn’t compatible with the torch inductor.

Thanks for your quick response! Let me explain my idea in the following sections:

Jagged Tensor

A jagged tensor is a fundamental data structure, widely used in NLP and recommendation systems (like DLRM).
Here is an example of how to construct jagged tensor by using torchrec.sparse:

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, offsets=offsets)

To visualize it, we convert it to a jagged view:

[[1.0, 2.0],
[],
[3.0],
[4.0],
[5.0],
[6.0, 7.0, 8.0]]

As we can see, this jagged tensor has a dimension whose slices may be of different lengths.
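The same jagged view can be reconstructed from values/offsets with stock torch ops, which may help make the structure concrete (illustrative only):

import torch

values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.tensor([0, 2, 2, 3, 4, 5, 8])
lengths = (offsets[1:] - offsets[:-1]).tolist()   # [2, 0, 1, 1, 1, 3]
jagged_view = list(torch.split(values, lengths))  # [[1., 2.], [], [3.], [4.], [5.], [6., 7., 8.]]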

Sparsity within edge_index using torch.Tensor

Let’s create a simple example here

row = torch.tensor([0, 1, 2, 3, 3, 4, 2, 4, 5, 2])
col = torch.tensor([1, 2, 3, 1, 5, 2, 4, 3, 3, 1])
edge_index = torch.stack([row, col], dim=0)

How do we detect that this edge_index is sparse?
This edge_index can be translated into a sparse matrix, noting that each column, [0, 1], [1, 2], [2, 3], [3, 1], [3, 5], [4, 2], [2, 4], [4, 3], [5, 3], [2, 1], represents an edge in the graph. This format is commonly used in graph neural network frameworks like PyTorch Geometric to represent graph structures. To detect whether it is sparse, we can use the following equation:

sparsity = num_edges / (num_nodes * num_nodes)

Usually, when the sparsity is less than 0.05, we consider the workload unsuitable for a dense representation and use SpMM rather than GEMM.

Back to the example given by PyG

num_nodes, num_edges = 10_000, 200_000
edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device)

The estimated sparsity here is 0.002, which, though quite small, aligns with the typical range observed in graph problems.
Note that this is an estimate, since randint cannot guarantee that every generated edge is unique.
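A quick way to check this (illustrative snippet; the exact count varies by random seed):

import torch

num_nodes, num_edges = 10_000, 200_000
edge_index = torch.randint(num_nodes, (2, num_edges))

estimated_sparsity = num_edges / (num_nodes * num_nodes)   # 0.002
unique_edges = torch.unique(edge_index, dim=1).size(1)     # deduplicated edge count
actual_sparsity = unique_edges / (num_nodes * num_nodes)   # slightly below 0.002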

Relationship between jagged tensor and edge_index

I appreciate @jansel 's comment about doing things more generically, and I am eager to promote a sparse compiler in a more general way, rather than using domain-specific knowledge. So can we convert edge_index to a jagged tensor? The answer is yes.

First of all, let us sort the edge_index example I gave above, and I think it can be easily acquired by adding a tag in PyG @rusty1s. Since we need to perform reduction along the column dimension, our edge_index needs to be sorted along the column axis. The result should be:

row = torch.tensor([0, 2, 3, 1, 4, 2, 4, 5, 2, 3])
col = torch.tensor([1, 1, 1, 2, 2, 3, 3, 3, 4, 5])
edge_index = torch.stack([row, col], dim=0)

Then, similarly, we can build a jagged tensor, which is indeed a CSR representation:

row = torch.tensor([0, 2, 3, 1, 4, 2, 4, 5, 2, 3])  # the sorted source nodes from above
offsets = torch.IntTensor([0, 0, 3, 5, 8, 9, 10])
edge_index = JaggedTensor(values=row, offsets=offsets)

How do we understand the sparsity and jagged features?

From my understanding, sparsity can be broken down into two aspects; take a sparse matrix as an example:

  • Firstly, the elements in different rows of the matrix are distinct, which essentially reflects its jagged nature.
  • Secondly, the spacing between elements in the same row is irregular and discrete, a characteristic I would describe as ‘jumpiness’.

Current hardware, whether CPUs or GPUs, fundamentally cannot address this ‘jumpiness’, because it involves accessing non-contiguous memory addresses, a limitation of current architectures. However, resolving the jagged issue already addresses a significant part of the sparsity problem, because it effectively tackles the atomic operations associated with unordered reductions, and these atomic operations are inefficient on both CPUs and GPUs. Therefore, my proposal primarily aims to address this atomic issue.

  1. In theory we should be able to extend torch.segment_reduce to support 2D. cc @albanD @mikaylagawarecki @serhaty on how hard that would be.

  2. Inductor should be able to fuse gather/index operations automatically. I think the real issue here is segment_reduce is a fallback op:
    https://github.com/pytorch/pytorch/blob/275403be165428d727dcd8e244d0ef05a4525be8/torch/_inductor/lowering.py#L2209
    Which means we never implemented a lowering for it. If we instead implemented an inductor lowering for segment_reduce (should be similar to scatter_reduce_), then I’d expect the fusion to happen.

  3. I think it should be possible to implement an inductor lowering for segment_coo. This can be done out of pytorch core by calling register_lowering.

The hardest codegen problem I see is taking advantage of the sorted indices. The inductor lowerings for the above ops that map to scatter_mode="atomic_add" would be easy to write, however they won’t be performant for data distributions with contention.

We could add a new mode like scatter_mode="sorted_add", then change the codegen to do some thread-local accumulation before issuing a smaller number of atomic_add ops. As you mentioned this would be easier on CPU than in Triton. Triton might be missing some features needed for this.

@peterbell10 – do you think your work on adding scan to Triton could be used for this application?


I am glad we reached a consensus on this problem, where the hardest challenge lies in code generation.

After deep contemplation and discussion with @wang-y-z, who has experience in fast codegen for SpMM using Triton, I find implementing a “sorted_add” IR impractical for several reasons:

  1. “atomic_add” is a more versatile approach for calculating scatter_add-related problems.
  2. “atomic_add” typically underperforms in graph-related problems or highly sparse data, primarily because many inputs converge to the same output address, creating significant contention.
  3. Writing “sorted_add” in CUDA or OpenMP is complex, and it significantly deviates from Triton’s current “scan” functionality. Adding this feature to Triton presents considerable challenges.

Given these factors, I propose a simpler yet effective solution, akin to adding passes at the FX level, e.g., the fx_passes of the inductor. Similar to our exploration in fused attention, we can fuse the GNN-GAS pattern in a pattern-matching manner. This would involve substituting the original pattern with a fused index_select_scatter_reduce operator. Unlike the inductor’s elementwise scheduler, which lowers the pattern entirely to atomic operations, this operator would be designed to minimize atomic operations by exploiting the sorted property of the reduction axis. This approach also circumvents the complex format transformation methods (like the CSR conversion demonstrated in the first post).
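As a rough FX-level sketch of that pattern replacement (the fused operator below is a stand-in composed of stock ops; a real implementation would dispatch to a custom kernel that exploits the sorted reduction axis):

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

# Hypothetical fused op; here just a stand-in built from stock ops.
def fused_index_select_scatter_reduce(x, row, col, dim_size):
    index = col.unsqueeze(-1).expand(-1, x.size(1))
    return x.new_zeros(dim_size, x.size(1)).scatter_reduce(0, index, x[row], reduce="sum")

def pattern(x, row, col, dim_size):
    x_j = x[row]
    index = col.unsqueeze(-1).expand(-1, x.size(1))
    return x.new_zeros(dim_size, x.size(1)).scatter_reduce(0, index, x_j, reduce="sum")

def replacement(x, row, col, dim_size):
    return fused_index_select_scatter_reduce(x, row, col, dim_size)

# Usage sketch: rewrite a traced graph before it reaches the Inductor.
# gm = symbolic_trace(model)
# subgraph_rewriter.replace_pattern(gm, pattern, replacement)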

A bit about myself: I previously worked at Nvidia optimizing GNN operators, and now I am collaborating with Intel to enhance CPU performance for GNN workloads. Inspired by the PyTorch 2.0 compiler, I believe that going beyond merely accelerating kernel-wise performance for specific hardware is crucial. I will dedicate myself to developing such a code demonstration later, providing a performance comparison with the inductor’s original codegen.


Pattern matching to a custom kernel, which I believe is what you are proposing, could solve the immediate problem. That seems easy to implement, and there are other examples of that in the codebase. The only disadvantage there is we couldn’t automatically fuse other things into it, and it is less reusable. I’d welcome a PR to add something like that.


I don’t think the sorted_add approach would be that hard to codegen either. The “scan” would be something like:

for i in range(len(indices)-1):
    if indices[i] == indices[i+1]:
        values[i+1] += values[i]
        values[i] = 0

this could operate locally on a single Triton block. (Or within the block assigned to a single CPU thread.) It doesn’t need to be global.

then you would do:

tl.atomic_add(ptr + indices, values, mask=mask & (values != 0))

this would dramatically reduce the number of atomic_adds needed, since most of them would be masked away and added together locally inside the kernel.

The scan that exists in Triton today couldn’t implement that, because it doesn’t support multiple inputs and doesn’t have a “carry” output (like jax.lax.scan) – but I could imagine extending it to do so.

Thanks for your reply! However, I disagree with your interpretation of the scan pattern.

From my perspective, a common case of scan pattern under a sorted index should be something like

res[indices[0]] = values[0]
for i in range(len(indices)-1):
    if indices[i] == indices[i+1]:
        res[indices[i+1]] += values[i+1]
    else:
        res[indices[i+1]] = values[i+1]

For example, consider the index array:

[0, 0, 1, 1, 1, 2, 2]

And the corresponding values array:

[a, b, c, d, e, f, g]

Applying the scan pattern, we would obtain the result:

[a+b, c+d+e, f+g]

I’m unsure why you set values[i]=0 in your example, as this would alter the original values array. Although your code snippet captures the reduced numbers, it requires reindexing to achieve the final result. What you end up with is:

[0, a+b, 0, 0, c+d+e, 0, f+g]

This outcome seems counterintuitive since it doesn’t correctly store the result. Therefore, I believe that the tl.atomic_add extension you proposed does not accurately represent the scan pattern. The main issue with Triton’s codegen here is the data-related mask. The mask should be something akin to indices[i+1] == indices[i], which contradicts the design principles of Triton primitives.

The scan I am talking about would be what happens inside the Triton block, without touching memory. The assigning to 0 is only there to allow masking away the atomic_add, which would reduce contention.

You start with a Triton block of XBLOCK elements, requiring XBLOCK atomic_adds. Many of those adds would write to the same elements, which would be slow.

You replace that with a Triton block of XBLOCK elements (Triton blocks are fixed size, so you are forced to keep it XBLOCK in size, the worst case). But now that block is mostly zeros, and every nonzero element points to a unique index. At this point you could issue a much smaller number of atomic_adds by masking away the unneeded ones.

In your example you would have:

indices = [0, 0, 1, 1, 1, 2, 2]
values = [0, a+b, 0, 0, c+d+e, 0, f+g]
mask = [False, True, False, False, True, False, True]
tl.atomic_add(ptr+indices, values, mask)

If you removed the masked items (which happens in hardware mask registers), you just get:

indices = [0, 1, 2]
values = [a+b, c+d+e, f+g]

Understood! This approach seems fantastic to me! However, there’s a significant challenge when considering a “carry” output in this context.

Let’s delve deeper into how we can transition from the original SchedulerNode to a sorted_add SchedulerNode. The following code snippet demonstrates the input to torch.compile, where x has a size of [10000, 32] and edge_index has a size of [2, 200000]:

def gather_scatter(x, edge_index, reduce="sum"):
    row, col = edge_index
    x_j = x[row]
    return scatter(x_j, col, dim_size=x.size(0), reduce=reduce)

In the context of the Inductor IR, when fusing index_select and scatter_add into a new node scheduler, we have the following:

buf1.users = [NodeUser(node=OUTPUT, can_inplace=False, is_weak=False)]
buf1.group.device = cuda:0
buf1.group.iteration = (6400000, 1)
buf1.sizes = ([200000, 32], [])
buf1.mutations = ['buf0']

class buf1_loop_body:
    var_ranges = {z0: 200000, z1: 32}
    index0 = z0 + 200000
    index1 = z0
    index2 = 32 * indirect1 + z1
    index3 = 32 * indirect0 + z1
    
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('arg1_1', get_index)
        set_indirect0 = self.set_indirect0(load)
        get_index_1 = self.get_index('index1')
        load_1 = ops.load('arg1_1', get_index_1)
        set_indirect1 = self.set_indirect1(load_1)
        get_index_2 = self.get_index('index2')
        load_2 = ops.load('arg0_1', get_index_2)
        get_index_3 = self.get_index('index3')
        store = ops.store('buf1', get_index_3, load_2, 'atomic_add')
        return store

In this pattern, index_3 relies on indirect0, which is expected to be a sorted axis (assuming we can configure and recognize it in advance). Given this setup, the question arises: How can we map this to a “sorted_add” IR?

This would mainly happen in codegen, not IR.

In the IR you would just replace:

store = ops.store('buf1', get_index_3, load_2, 'atomic_add')

with

store = ops.store('buf1', get_index_3, load_2, 'sorted_add')

The tricky part would actually be codegen for 'sorted_add', since it relies on Triton features that don’t exist yet.

Basically rather than emitting the Triton code for atomic_add:

tl.atomic_add(ptr+indices, values, mask)

you would emit:

new_values = tl.associative_scan(....<new triton scan op would go here>...)
tl.atomic_add(ptr+indices, new_values, mask & (new_values!=0))

I suppose you could also simplify the scan op by precomputing the mask based on indices[i] != indices[i+1] – then the scan op would be similar to tl.cumsum, except that it would reset the count to 0 whenever the mask was True. It would be something like a “conditional cumulative sum”.

Though @peterbell10 (who I believe is on vacation, but will hopefully respond once he is back) would know more there, since he implemented some of the Triton scan stuff.

Thank you so much! I find this solution awesome. My understanding of the required front-end changes is as follows:

def gather_scatter(x, edge_index, reduce="sum", sorted=True):
    row, col = edge_index  # col is sorted
    x_j = x[row]
    return scatter(x_j, col, reduce, sorted_axis=sorted)

In this modification, we introduce a sorted flag and add a sorted_axis tag to the scatter operator.

Consequently, we will generate a SchedulerNode that looks something like this:

buf1.users = [NodeUser(node=OUTPUT, can_inplace=False, is_weak=False)]
buf1.group.device = cuda:0
buf1.group.iteration = (6400000, 1)
buf1.sizes = ([200000, 32], [])
buf1.mutations = ['buf0']

class buf1_loop_body:
    var_ranges = {z0: 200000, z1: 32}
    index0 = z0 + 200000
    index1 = z0
    index2 = 32 * indirect1 + z1
    index3 = 32 * indirect0 + z1
    
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('arg1_1', get_index)
        set_indirect0 = self.set_indirect0(load)
        get_index_1 = self.get_index('index1')
        load_1 = ops.load('arg1_1', get_index_1)
        set_indirect1 = self.set_indirect1(load_1)
        get_index_2 = self.get_index('index2')
        load_2 = ops.load('arg0_1', get_index_2)
        get_index_3 = self.get_index('index3')
        store = ops.store('buf1', get_index_3, load_2, 'sorted_add')
        return store

I’m planning to initially try this out with OpenMP backend codegen, as it’s unlikely to be hindered by Triton’s limitations.


Hello @jansel, after some hands-on experimentation and reflection, I’ve come to realize that the solution may not be appropriate when considering tl.associative_scan. This Triton primitive is tailored for associative operations, where the function used for element combination (combine_fn) needs to adhere to the associative property: (a op b) op c should equal a op (b op c) for any elements a, b, and c. However, the principle of segment reduction does not align with associative computation patterns, indicating that it cannot be effectively handled by associative_scan. Additionally, if my understanding is correct, this pattern employs parallel reduction using CUDA’s __shfl primitive, which differs from sequential reduction. The accompanying figure, sourced from Wikipedia, illustrates the concept of a parallel scan.

Inspired by dgSPARSE’s solution of spmm_coo_sorted, credited to @hgyhungry, I find it realistic to codegen a parallel solution for both OpenMP and Triton. Surprisingly, I found OpenMP is even harder than Triton. Here is what I implemented with OpenMP, without a carry input.

#include <vector>

void segment_sorted_add(std::vector<float> &res, const std::vector<int> &indices,
                        const std::vector<float> &values) {
  int n = indices.size();

  // Check if indices and values vectors are empty
  if (n == 0)
    return;

#pragma omp parallel
  {
    float acc = 0.0;     // Accumulator for the current index
    int last_index = -1; // Sentinel: no index has been processed yet
#pragma omp for nowait
    for (int i = 0; i < n; ++i) {
      if (i == 0 || indices[i] != last_index) {
        if (last_index != -1) {
          // atomic_add: small helper for an atomic float add (e.g. via '#pragma omp atomic')
          atomic_add(&res[last_index],
                     acc); // Flush the accumulated value for the previous index
        }
        acc = values[i];         // Reset accumulator for new index
        last_index = indices[i]; // Update the last index
      } else {
        // Same index, accumulate the values
        acc += values[i];
      }
    }
    // Flush the accumulator for the final index handled by this thread
    if (last_index != -1)
      atomic_add(&res[last_index], acc);
  }
}

Here I also provide a simple test case via gist.

The issue with OpenMP arises when dealing with a ‘carry’ input, for example, when the value is a 2-D tensor. In such cases, it tends to fail due to the limited register file on CPUs, which necessitates managing a buffer for acc and requires allocation through malloc or new.

Yes, I am aware of that. That is why I said:

The scan that exists in Triton today couldn’t implement that, because it doesn’t support multiple inputs and doesn’t have a “carry” output (like jax.lax.scan) – but I could imagine extending it to do so.

And:

I suppose you could also simplify the scan op by precomputing the mask based in indices[i] != indice[i+1] – then the scan op would be similar to tl.cumsum, except that it would reset the count to 0 whenever the mask was True. It would be something like a “conditional cumulative sum”.

To elaborate on the simplified one, the op would operate on a tuple of (value, last_mask). The initial values of last_mask would be indices[i] != indices[i+1]. When you combine two values, you always take the right one.

Then scan op would be:

def masked_cumsum_combine(left, right):
    if left.last_mask:
        # the prior value has already been written out, just ignore it
        # this resets the running count to 0 whenever the index changes
        return (right.value, right.last_mask)
    else:
        # the prior write won't happen, need to include it in the sum
        return (left.value + right.value, right.last_mask)

This fits into a basic scan op without a “carry” variant, though it requires a tuple of two inputs.

This will compute a cumulative sum, except whenever the index changes, it resets the running count to zero. You ignore any partial values when the index doesn’t change. Note in this version the ignored/masked-away values won’t be zero. So you can’t use values!=0 for the mask. You will need to use the mask based on the index changing.
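For reference, folding that combine function sequentially over the earlier [0, 0, 1, 1, 1, 2, 2] example gives the following (a plain-Python illustration of the semantics, not the actual Triton scan):

indices = [0, 0, 1, 1, 1, 2, 2]
values = ["a", "b", "c", "d", "e", "f", "g"]  # symbolic values for readability
last_mask = [indices[i] != indices[i + 1] if i + 1 < len(indices) else True
             for i in range(len(indices))]

scanned, running = [], None
for v, m in zip(values, last_mask):
    running = v if running is None else running + "+" + v
    scanned.append(running)
    if m:
        running = None  # index changes: this partial sum is the one that gets written out

print(scanned)    # ['a', 'a+b', 'c', 'c+d', 'c+d+e', 'f', 'f+g']  (masked-away entries are nonzero partials)
print(last_mask)  # [False, True, False, False, True, False, True]
# Masked atomic_add keeps only the True positions: a+b, c+d+e, f+g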

Yes, I think we reached a consensus on this, but my idea is that we can use Triton to write the segment_reduce for the carry input without using parallel scan. The key approach is to do a sequential scan within a block.

Here is my code example in triton:

@triton.jit
def spmm_sorted_coo_naive(edge_index, B, C, num_edges, feature_size: tl.constexpr, group_size: tl.constexpr):
    group_id = tl.program_id(0)
    node_offset = group_id * group_size
    f_index = tl.arange(0, feature_size)

    xn = node_offset
    mask = xn < num_edges
    in_node = tl.load(edge_index + xn, mask=mask)  # Load the input node
    out_node = tl.load(edge_index + xn + num_edges,
                       mask=mask)  # Load the output node
    curr_node = out_node
    val = tl.load(B + in_node * feature_size + f_index, mask=mask)
    for ii in range(1, group_size):  # Iterate over the group
        xn = ii + node_offset  # Get the node index
        mask = xn < num_edges  # Check if the node index is valid
        in_node = tl.load(edge_index + xn, mask=mask)  # Load the input node
        out_node = tl.load(edge_index + xn + num_edges,
                           mask=mask)  # Load the output node
        new_val = tl.load(B + in_node * feature_size + f_index, mask=mask)
        if out_node != curr_node:
            # Perform atomic addition
            tl.atomic_add(C + curr_node * feature_size +
                          f_index, val, mask=mask)
            # Reset val for the new row
            val = new_val
            curr_node = out_node
        else:
            # Accumulate val
            val += new_val

    tl.atomic_add(C + out_node * feature_size + f_index, val, mask=mask)

The key idea here is to map the logic index[i] != index[i+1] into cond=out_node != curr_node, and place this logic within a sequential loop.

Thus, instead of writing something like

new_values = tl.associative_scan(....<new triton scan op would go here>...)
tl.atomic_add(ptr+indices, new_values, mask & (new_values!=0))

where I believe associative_scan behaves as a parallel reduction rather than a sequential one,
we can directly leverage the expressive power of the existing Triton primitives and obtain a version of the sorted scatter function with fewer atomic operations.

I have verified the correctness and performance of this code. Here is a test code via gist.
On RTX 3080, this sorted version fused index_scatter reaches 2x speedup compared to the fully atomic kernel. I will try to test more performance benchmarks and conclude the codegen flow of this pattern later.

I think either one could work, though which is best depends on the sizes. For a very small group size you might not have enough parallelism with your approach.

I think my example above was actually slightly wrong. It would need to be:

# input is namedtuple of (value, rightmost_index)
def masked_cumsum_combine(left, right):
   if left.rightmost_index == right.rightmost_index:
        return (left.value + right.value, right.rightmost_index)
   else:
        return (right.value, right.rightmost_index)

Indeed, I concur that both reduction methods are valid for expressing segment_reduction semantics. However, parallel scanning within a warp proves efficient for smaller feature sizes, while sequential scanning within a thread is more effective for larger sizes.

Addressing the bottlenecks in frontend, IR, and codegen, I’ve outlined a potential roadmap below:

Frontend Design

Rather than utilizing the scatter(..., sorted=True) expression, I’m inclined towards adopting segment_reduce. Here are my reasons for preferring the segment_reduce operator: it’s commonly used in histograms, graph analysis, GNNs, etc. While there’s a closed call for implementation as noted in this Issue, its functionality still lags behind segment_coo from the torch-scatter library. The primary motivation for choosing segment_reduce is its ability to preserve the sorted-index information from the frontend. Another major aspect is the IR design, which I discuss in the following section.
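A sketch of the desired frontend shape, assuming segment_reduce is extended to accept CSR-style offsets and 2-D data along axis 0 (both are the extensions discussed above, not current behavior):

import torch

def gather_segment_reduce(x, row, offsets):
    # row and offsets come from the CSR preprocessing step; the index is sorted by construction.
    x_j = x[row]  # gather in sorted destination order
    # Hypothetical: relies on segment_reduce handling 2-D data and CSR-style offsets.
    return torch.segment_reduce(x_j, "sum", offsets=offsets, axis=0)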

IR Design

The Inductor typically lowers the scatter operator to a Scatter IR derived from the Pointwise IR. However, Pointwise IR doesn’t efficiently establish loop structures for operators akin to segment_reduce. An ideal component diagram is illustrated below:

Codegen Design

Focusing on simplicity, let’s consider the Triton backend as an example. The main challenge in Triton’s codegen is its current lack of support for segment_scan IR. A feasible approach is to enhance the existing scan to accommodate segment scanning. This, however, is a complex task, and I’m still exploring the ToLLVM extension.

An endnote on scan: there are essentially two patterns for scanning. One is a sequential scan within a single thread, and the other is a parallel scan within a warp. Both patterns need to be considered when extending scan to support segment_reduce.