I think you are looking for torch.Tensor.index_add_ (see the PyTorch 2.2 documentation).
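As a rough illustration of how index_add_ can act as a segment sum, here is a minimal sketch. The tensors `values` and `segment_ids` and the count `num_segments` are made-up example data, not anything from your model:

```python
import torch

# Hypothetical data: five rows, each assigned to one of three segments.
values = torch.tensor([[1.0, 2.0],
                       [3.0, 4.0],
                       [5.0, 6.0],
                       [7.0, 8.0],
                       [9.0, 10.0]])
segment_ids = torch.tensor([0, 0, 1, 2, 2])
num_segments = 3

# Start from zeros and accumulate each row of `values` into the output row
# selected by the matching entry of `segment_ids` (along dim 0).
segment_sums = torch.zeros(num_segments, values.shape[1])
segment_sums.index_add_(0, segment_ids, values)

print(segment_sums)
```

Rows that share a segment id are summed together, so the segment ids do not need to be sorted or contiguous for this to work.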
More generally, in PyTorch we already have a (non-public lol) segment_reduce under torch._segment_reduce. We could:
- Make this function public, write proper docs, clean its API if needed, etc.
- Have codegen for this function. This can already be done in triton and in our IR, as our IR supports multiple input reductions.
Would this give you enough support to implement your models?
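In the meantime, for reductions other than a sum, something close to a segment reduce can be sketched with the public torch.Tensor.scatter_reduce_. This is only a stand-in under assumed example data (`values`, `segment_ids`, `num_segments` are hypothetical), and it does not call the private torch._segment_reduce:

```python
import torch

# Hypothetical data: five scalars, each assigned to one of three segments.
values = torch.tensor([1.0, 5.0, 2.0, 7.0, 3.0])
segment_ids = torch.tensor([0, 0, 1, 2, 2])
num_segments = 3

# Initialise with -inf so that every real value wins the max comparison,
# then reduce each value into the slot named by its segment id.
segment_max = torch.full((num_segments,), float("-inf"))
segment_max.scatter_reduce_(0, segment_ids, values, reduce="amax", include_self=True)

print(segment_max)
```

A segment that receives no values would keep its initial -inf here, so callers may want to handle empty segments explicitly.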