Understood! This approach seems fantastic to me! However, there’s a significant challenge when considering a “carry” output in this context.
Let’s delve deeper into how we can transition from the original SchedulerNode to a sorted_add SchedulerNode. The following snippet is the input to torch.compile, where x has shape [10000, 32] and edge_index has shape [2, 200000]:
from torch_geometric.utils import scatter

def gather_scatter(x, edge_index, reduce="sum"):
    row, col = edge_index
    x_j = x[row]  # gather: index_select along dim 0
    return scatter(x_j, col, dim_size=x.size(0), reduce=reduce)
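For reference, the gather-scatter semantics above (with reduce="sum") can be sketched in plain Python, independent of torch. The helper name gather_scatter_ref and the tiny 3-node / 4-edge example are illustrative, not from the original code:

```python
def gather_scatter_ref(x, edge_index, dim_size):
    # x: list of per-node feature rows; edge_index: (row, col) index lists.
    row, col = edge_index
    out = [[0.0] * len(x[0]) for _ in range(dim_size)]
    for r, c in zip(row, col):
        # x_j = x[row]: gather the source-node features, then
        # scatter-add them into the destination row given by col.
        for k, v in enumerate(x[r]):
            out[c][k] += v
    return out

# Tiny example: 3 nodes with 2 features each, 4 edges.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
edge_index = ([0, 1, 2, 1], [1, 0, 0, 2])
print(gather_scatter_ref(x, edge_index, dim_size=3))
# → [[8.0, 10.0], [1.0, 2.0], [3.0, 4.0]]
```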
In the Inductor IR, when fusing the index_select and scatter_add into a new SchedulerNode, we get the following:
buf1.users = [NodeUser(node=OUTPUT, can_inplace=False, is_weak=False)]
buf1.group.device = cuda:0
buf1.group.iteration = (6400000, 1)
buf1.sizes = ([200000, 32], [])
buf1.mutations = ['buf0']
class buf1_loop_body:
    var_ranges = {z0: 200000, z1: 32}
    index0 = z0 + 200000
    index1 = z0
    index2 = 32 * indirect1 + z1
    index3 = 32 * indirect0 + z1
    def body(self, ops):
        get_index = self.get_index('index0')
        load = ops.load('arg1_1', get_index)
        set_indirect0 = self.set_indirect0(load)
        get_index_1 = self.get_index('index1')
        load_1 = ops.load('arg1_1', get_index_1)
        set_indirect1 = self.set_indirect1(load_1)
        get_index_2 = self.get_index('index2')
        load_2 = ops.load('arg0_1', get_index_2)
        get_index_3 = self.get_index('index3')
        store = ops.store('buf1', get_index_3, load_2, 'atomic_add')
        return store
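To make the loop body concrete, here is a hypothetical Python rendering of it, assuming arg1_1 is edge_index flattened row-major (row entries first, then col entries) and arg0_1 is x flattened; the function name buf1_body is made up for illustration:

```python
def buf1_body(arg0_1, arg1_1, buf1, n_edges, n_feat):
    # Iteration space matches var_ranges = {z0: n_edges, z1: n_feat}.
    for z0 in range(n_edges):
        indirect0 = arg1_1[z0 + n_edges]  # index0: col = edge_index[1][z0]
        indirect1 = arg1_1[z0]            # index1: row = edge_index[0][z0]
        for z1 in range(n_feat):
            # load x[row][z1] via index2, then store with 'atomic_add'
            # into buf1[col][z1] via index3.
            buf1[indirect0 * n_feat + z1] += arg0_1[indirect1 * n_feat + z1]
```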
In this pattern, index3 depends on indirect0, which is expected to be a sorted axis (assuming we can mark it as sorted ahead of time and recognize that during lowering). Given this setup, the question is: how can we map this pattern to a “sorted_add” IR?
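One way to picture the target semantics, as a sketch only: if the scatter index is sorted, each output row corresponds to a contiguous run of iterations, so the atomic_add can in principle become a local accumulation plus a single plain store per run. The function sorted_add below is a hypothetical reference for those semantics, not the proposed IR itself:

```python
def sorted_add(values, sorted_index, dim_size, n_feat):
    # values: gathered rows flattened, ordered so that sorted_index is
    # non-decreasing; rows of out with no incoming index stay zero.
    out = [0.0] * (dim_size * n_feat)
    n = len(sorted_index)
    z0 = 0
    while z0 < n:
        dest = sorted_index[z0]
        acc = [0.0] * n_feat
        # Accumulate the whole run of equal indices locally
        # (the "carry" never leaves the run).
        while z0 < n and sorted_index[z0] == dest:
            for z1 in range(n_feat):
                acc[z1] += values[z0 * n_feat + z1]
            z0 += 1
        for z1 in range(n_feat):
            out[dest * n_feat + z1] = acc[z1]  # single non-atomic store
    return out
```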