For example, if I do a max operation:
```python
import torch

def fn(x):
    return torch.max(x, 1)

opt_fn = torch.compile(fn)
x = torch.randn(16, 2**20, device="cuda")
y = opt_fn(x)
```
I found that the generated Triton code uses two kernels for this reduction op: the first kernel reduces over only part of the reduction dimension, producing partial results of shape e.g. (16, 2**8), and the second kernel performs the final reduction to produce the (16, 1) result. This magic happens in code in ir.py.
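To make it concrete, here is a rough PyTorch-level sketch of what the two generated kernels effectively compute. The split factor 2**12 is purely illustrative (an assumption on my part); Inductor chooses the actual split at compile time:

```python
import torch

x = torch.randn(16, 2**20, device="cuda")

# Kernel 1: split the 2**20 reduction dim into 2**8 chunks of 2**12 each
# and reduce within each chunk, producing partial maxima.
partial = x.view(16, 2**8, 2**12).amax(dim=2)  # shape (16, 2**8)

# Kernel 2: reduce the partial results down to the final answer.
values = partial.amax(dim=1)                   # shape (16,)
```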
But could anyone help me understand why this strategy is needed? Is it because of the SMs' capacity, or for other performance considerations?
Thanks very much!