PyTorch Sparse(GNN) Compiler RFC

Understood! This approach seems fantastic to me! However, there’s a significant challenge when considering a “carry” output in this context.

Let’s delve deeper into how we can transition from the original SchedulerNode to a sorted_add SchedulerNode. The following snippet is the input to torch.compile, where `x` has shape `[10000, 32]` and `edge_index` has shape `[2, 200000]`:

from torch_geometric.utils import scatter  # or torch_scatter.scatter

def gather_scatter(x, edge_index, reduce="sum"):
    row, col = edge_index
    x_j = x[row]  # gather: index_select along dim 0
    return scatter(x_j, col, dim_size=x.size(0), reduce=reduce)  # scatter_add when reduce="sum"
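For readers without a PyG environment, the gather/scatter semantics can be emulated in plain Python. This is a minimal reference sketch; `gather_scatter_ref`, the toy sizes, and the list-based tensors are illustrative stand-ins, not part of the RFC:

```python
def gather_scatter_ref(x, edge_index, reduce="sum"):
    """Reference semantics: out[col[e]] += x[row[e]] for every edge e."""
    assert reduce == "sum"  # only the sum case is sketched here
    row, col = edge_index
    num_nodes, feat = len(x), len(x[0])
    out = [[0.0] * feat for _ in range(num_nodes)]
    for e in range(len(row)):
        x_j = x[row[e]]               # gather (index_select)
        for f in range(feat):
            out[col[e]][f] += x_j[f]  # scatter_add
    return out

# Two edges, both writing into node 0:
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
edge_index = ([1, 2], [0, 0])  # (row = sources, col = targets)
# gather_scatter_ref(x, edge_index) → [[8.0, 10.0], [0.0, 0.0], [0.0, 0.0]]
```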

In the Inductor IR, fusing index_select and scatter_add into a single SchedulerNode yields the following:

buf1.users = [NodeUser(node=OUTPUT, can_inplace=False, is_weak=False)]
buf1.group.device = cuda:0
buf1.group.iteration = (6400000, 1)
buf1.sizes = ([200000, 32], [])
buf1.mutations = ['buf0']

class buf1_loop_body:
    var_ranges = {z0: 200000, z1: 32}
    index0 = z0 + 200000
    index1 = z0
    index2 = 32 * indirect1 + z1
    index3 = 32 * indirect0 + z1
    
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('arg1_1', get_index)
        set_indirect0 = self.set_indirect0(load)
        get_index_1 = self.get_index('index1')
        load_1 = ops.load('arg1_1', get_index_1)
        set_indirect1 = self.set_indirect1(load_1)
        get_index_2 = self.get_index('index2')
        load_2 = ops.load('arg0_1', get_index_2)
        get_index_3 = self.get_index('index3')
        store = ops.store('buf1', get_index_3, load_2, 'atomic_add')
        return store
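The loop body above can be read as a flat loop over `(z0, z1)`: `index0`/`index1` load `col[z0]`/`row[z0]` from the flattened `edge_index`, `index2` gathers the source row of `x`, and `index3` atomically accumulates into the output. A Python transliteration of that reading (our sketch, with atomic_add modeled as a plain `+=`; the argument names follow the IR dump, and the flattening convention is our assumption):

```python
def buf1_loop_body_ref(arg0_1, arg1_1, buf1, num_edges, feat):
    """Transliteration of buf1_loop_body.

    arg1_1: edge_index flattened row-major ([row..., col...]),
    arg0_1: x flattened row-major, buf1: output flattened row-major.
    """
    for z0 in range(num_edges):
        indirect0 = arg1_1[z0 + num_edges]  # index0 = z0 + 200000 -> col[z0]
        indirect1 = arg1_1[z0]              # index1 = z0          -> row[z0]
        for z1 in range(feat):
            load_2 = arg0_1[feat * indirect1 + z1]  # index2: x[row[z0], z1]
            buf1[feat * indirect0 + z1] += load_2   # index3: 'atomic_add' store
    return buf1

# Same toy data as before: 3 nodes, 2 features, 2 edges into node 0.
arg0_1 = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # x, flattened
arg1_1 = [1, 2, 0, 0]                    # edge_index: rows then cols
# buf1_loop_body_ref(arg0_1, arg1_1, [0.0] * 6, 2, 2)
# → [8.0, 10.0, 0.0, 0.0, 0.0, 0.0]
```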

In this pattern, `index3` depends on `indirect0`, which is expected to be a sorted axis (assuming we can configure and recognize this in advance). Given this setup, the question arises: how can we map this to a “sorted_add” IR?
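One way to picture what a “sorted_add” could exploit: when the store index is non-decreasing, consecutive iterations that hit the same output row can accumulate into a local “carry” register and flush it once the index changes, so each output row is written exactly once and no atomics are needed. The following is our illustrative Python sketch of that segment-reduction idea, not an existing Inductor IR node:

```python
def sorted_add_ref(src_rows, sorted_out_idx, out):
    """Segment-style scatter_add: valid only when sorted_out_idx is
    non-decreasing. Accumulates into a local carry and writes each
    output row once, instead of issuing one atomic_add per element."""
    feat = len(out[0]) if out else 0
    carry = [0.0] * feat
    prev = None
    for e, idx in enumerate(sorted_out_idx):
        if prev is not None and idx != prev:
            for f in range(feat):       # index changed: flush the carry
                out[prev][f] += carry[f]
            carry = [0.0] * feat
        for f in range(feat):           # same segment: accumulate locally
            carry[f] += src_rows[e][f]
        prev = idx
    if prev is not None:                # flush the final segment
        for f in range(feat):
            out[prev][f] += carry[f]
    return out

# Three gathered rows, sorted targets [0, 0, 1]: rows 0 and 1 share a
# segment and are flushed together, row 2 forms its own segment.
rows = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
# sorted_add_ref(rows, [0, 0, 1], [[0.0, 0.0], [0.0, 0.0]])
# → [[3.0, 3.0], [3.0, 3.0]]
```

This is also where the “carry” output mentioned earlier shows up: at a tile or thread-block boundary, the open segment’s partial sum has to be carried out and combined with the neighboring tile’s first segment.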