PyTorch Sparse (GNN) Compiler RFC

Thank you for the information! It’s exactly what I’ve been looking for. I’ve just reached out to the PyG team to ask why they use a broadcast followed by scatter_reduce instead of index_reduce (ref: PyG’s code).
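To make the two patterns concrete, here is a minimal sketch. It assumes edge features `src` of shape [E, F] and a 1-D destination index of length E. The function names are hypothetical, and this is not PyG’s actual code. The first function broadcasts the index to `src`’s shape and calls `scatter_reduce`. The second passes the 1-D index straight to `index_reduce_`. For the reductions both ops support ("prod", "mean", "amax", "amin"), the results should agree. `index_reduce` has no "sum" mode, and `index_add_` is the usual choice for that.

```python
import torch

def reduce_broadcast_scatter(src: torch.Tensor, index: torch.Tensor,
                             num_nodes: int, reduce: str = "amax") -> torch.Tensor:
    """Hypothetical helper: broadcast-then-scatter_reduce pattern (PyG-style sketch)."""
    # Expand the 1-D index to src's shape: [E] -> [E, F] (a stride-0 view, no copy).
    idx = index.view(-1, 1).expand_as(src)
    out = src.new_zeros(num_nodes, src.size(1))
    # include_self=False: reduce only over the scattered values, not the zeros in `out`.
    return out.scatter_reduce(0, idx, src, reduce=reduce, include_self=False)

def reduce_index_reduce(src: torch.Tensor, index: torch.Tensor,
                        num_nodes: int, reduce: str = "amax") -> torch.Tensor:
    """Hypothetical helper: the same reduction via index_reduce_ with a 1-D index."""
    out = src.new_zeros(num_nodes, src.size(1))
    # No broadcast needed: index_reduce_ selects whole rows along dim 0.
    return out.index_reduce_(0, index, src, reduce, include_self=False)
```

One apparent difference is that `scatter_reduce` needs an index with the same number of dimensions as `src`, so the 1-D index must be expanded first. `index_reduce_` consumes the 1-D index directly.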

I previously considered segment_reduce, but I no longer see it as critical. Adding a sorted_index flag to index_reduce would effectively make it equivalent to segment_reduce while keeping it more general.
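A minimal sketch of why a sorted index turns index_reduce into a segment reduction follows. Once the index is sorted, every group becomes a contiguous slice that CSR-style offsets (`ptr`) can describe, and that slice layout is what a segment reduction consumes. The `sorted_index` flag above is a proposal, not an existing `index_reduce` argument. Both helper names here are hypothetical. The segment side uses plain slicing and a Python loop for clarity, rather than PyTorch’s prototype `segment_reduce` operator, so it doesn’t depend on that op’s exact signature. A real kernel would instead map segments to parallel workers.

```python
import torch

def index_reduce_amax(src: torch.Tensor, index: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Reference: index_reduce_ accepts the index in any order."""
    out = src.new_zeros(num_segments, *src.shape[1:])
    return out.index_reduce_(0, index, src, "amax", include_self=False)

def sorted_segment_amax(src: torch.Tensor, index: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Hypothetical helper: sort once, then reduce each contiguous segment."""
    sorted_index, perm = torch.sort(index)
    sorted_src = src[perm]
    # CSR-style offsets: segment i spans sorted_src[ptr[i]:ptr[i + 1]].
    counts = torch.bincount(sorted_index, minlength=num_segments)
    ptr = torch.zeros(num_segments + 1, dtype=torch.long)
    ptr[1:] = torch.cumsum(counts, dim=0)
    out = src.new_zeros(num_segments, *src.shape[1:])
    for i in range(num_segments):
        start, end = int(ptr[i]), int(ptr[i + 1])
        if end > start:  # empty segments stay 0, like untouched rows with include_self=False
            out[i] = sorted_src[start:end].amax(dim=0)
    return out
```

If the caller can promise the index is already sorted, the `torch.sort` step disappears and only the offsets remain. That is the same information a segment-based op takes as input.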

I’m excited about the possibility of code generation for index_reduce with sorting! However, I’m curious about the following statement:

"Based on my experience, Triton does not yet support the sorted scan function, and the torch IR does not accommodate such a reduction pattern."

Could you clarify this a bit further and, if possible, share some code references (e.g., links to the relevant code in the PyTorch repo)? Thank you!