What is the split strategy for large-dimensional reduction operations in triton code generation?

For example, if I do max operation:

import torch

def fn(x):
    # reduce over dim=1 (size 2**20); torch.max(x, 1) returns (values, indices)
    return torch.max(x, 1)

opt_fn = torch.compile(fn)
x = torch.randn(16, 2**20, device="cuda")
y = opt_fn(x)

I found that the generated Triton code uses two kernels for this reduction op: the first kernel does a partial reduction, producing, for example, a (16, 2**8) tensor, and the second kernel does the final reduction, producing the (16, 1) result. This magic happens in code in ir.py.
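For reference, here is a minimal PyTorch sketch of what the two stages compute (not the actual Inductor-generated Triton code); the split factor of 2**8 is taken from the intermediate shape above, and the contiguous chunking is an assumption for illustration:

import torch

# Rough sketch of the two-stage reduction (not Inductor's generated kernels).
x = torch.randn(16, 2**20, device="cuda")

# Stage 1: split each row into 2**8 chunks of 2**12 elements and reduce every
# chunk independently, giving a (16, 2**8) tensor of partial maxima.
partial = x.view(16, 2**8, 2**12).amax(dim=2)

# Stage 2: a small final reduction over the 2**8 partial results per row.
values = partial.amax(dim=1)

# Matches the direct single-pass reduction.
assert torch.equal(values, x.amax(dim=1))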

But could anyone help me understand why this strategy is needed? Is it because of SM capacity, or for performance reasons?

Thanks very much!

This is split reductions. The goal is to get good utilization: if you’re reducing over dim=1, the naive strategy is to assign work to SMs based on dim=0, but in your case the tensor only has 16 rows, which wouldn’t fully utilize the GPU. Splitting dim=1 lets us get full utilization at the cost of one small final reduction at the end.
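To make the utilization argument concrete, here is a rough back-of-the-envelope sketch; the SM count of 108 and the split factor of 2**8 are illustrative assumptions, not values Inductor necessarily uses:

# Occupancy estimate, assuming one program (block) per output element of the
# non-reduced dimensions and a GPU with 108 SMs (illustrative number).
num_sms = 108

rows = 16                     # naive strategy: one block per row of dim=0
naive_blocks = rows           # only 16 blocks, so most SMs sit idle

split = 2**8                  # split dim=1 into 256 chunks
split_blocks = rows * split   # 4096 blocks, enough work for every SM

print(f"naive: {min(naive_blocks, num_sms)}/{num_sms} SMs busy")
print(f"split: {min(split_blocks, num_sms)}/{num_sms} SMs busy")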
