The feature, motivation and pitch
torch.compile has been included in PyTorch since version 2.0 and has proven to be a powerful way to greatly speed up PyTorch code. The PyTorch Inductor backend has pioneered graph-mode execution, significantly accelerating a spectrum of machine learning applications, including CNNs, Transformers, and Graph Neural Networks (GNNs). Despite these advancements, end-to-end optimization for PyG-based GNNs remains suboptimal. The root of this inefficiency lies in the current implementation of the fused scatter_add kernel, e.g. fused_index_new_zeros_scatter_add_0. Profiling results show that this fused scatter_add kernel accounts for over 80% of the running time of typical GNN models. The code provided here shows that the codegen of fused gather-scatter lowers to atomic operations, which significantly hurts performance. This proposal sets forth a strategy to refine this mechanism, focusing initially on CPU backends due to their less complex thread hierarchy compared to GPUs and the current limitations of the Triton backend's warp-level specification. Check here for the Triton limitation for sparse kernel codegen.
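To make the atomics problem concrete, here is a small NumPy sketch (NumPy stands in for the generated kernel; buffer names are illustrative) contrasting the scatter-style reduction, whose unordered writes to the destination rows force atomic adds when parallelized over edges, with a CSR-style reduction in which each output row is owned by exactly one iteration:

```python
import numpy as np

# Toy gather-scatter aggregation (the pattern behind fused scatter_add):
# for each edge (src -> dst), add the src feature row into the dst row.
edge_src = np.array([0, 1, 2, 2])
edge_dst = np.array([1, 0, 0, 1])
feats = np.arange(12, dtype=np.float64).reshape(3, 4)  # 3 nodes, 4 features

# Scatter-style reduction: unordered writes into out[dst] require atomic
# adds once the edge loop is parallelized.
out_scatter = np.zeros_like(feats)
np.add.at(out_scatter, edge_dst, feats[edge_src])

# CSR-style reduction: group edges by destination row first; each output
# row is then written by a single iteration, so no atomics are needed.
order = np.argsort(edge_dst, kind="stable")
csr_ind = edge_src[order]                        # column index per edge
csr_ptr = np.zeros(feats.shape[0] + 1, dtype=np.int64)
np.add.at(csr_ptr, edge_dst + 1, 1)              # per-row edge counts
csr_ptr = np.cumsum(csr_ptr)                     # exclusive prefix sum

out_csr = np.zeros_like(feats)
for row in range(feats.shape[0]):                # parallelizable, no atomics
    for e in range(csr_ptr[row], csr_ptr[row + 1]):
        out_csr[row] += feats[csr_ind[e]]

assert np.allclose(out_scatter, out_csr)
```

Both loops compute the same aggregation; the CSR variant is the shape of code this proposal wants Inductor to generate.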
Alternatives
Option 1
The first option is to develop an independent compiler pass specifically for sparse tensors. However, this approach is challenging for two major reasons. First, the backend of the Inductor scheduler is mainly designed for pointwise kernel fusion. Second, the input for a SparseTensor can vary significantly, both in terms of the programming interface and the format used. For instance, PyTorch Geometric (PyG) employs its own SparseTensor input, which differs from the native sparse tensor format used by PyTorch, as detailed in the official PyTorch documentation.
Option 2
Another alternative is to register the sparse tensor and sparse SpMM operations within torch.ops. The Inductor would then default to calling an external kernel for SpMM-related operations (similar to segment_reduce). This approach has its downsides too, as it may diminish opportunities to fuse kernels around SpMM-like or scatter_reduce operations.
Additional context
Objective:
The primary goal is to enhance the PyTorch Inductor's capability for optimizing scatter-based GNN end-to-end models, specifically targeting CPU backends. The ideal codegen result will be similar to the CSR part of the aten::scatter_reduce kernel.
Expected Performance:
These enhancements are expected to bridge the gap between the current and optimal performance for GNNs, thereby unlocking new efficiencies and potential within PyTorch's framework. Technically, this project targets:
- The codegen result will achieve the kernel-wise performance proposed by this code, which showed speedups of 4-7x.
- The end-to-end result is expected to reach 2-3x, based on the breakdown profiling of mainstream GNN models.
Techniques:
We will be working on the following new ideas:
- Pattern Matching for Scatter-Based Operations:
Investigation reveals that kernels involving scatter operations devolve to atomic-based IRs, which lead to subpar code generation. A pattern-matching strategy will be developed to identify and address these inefficiencies. Note that this pattern matching happens in inductor/scheduler.py. We will add a sparse scheduler pass after the fusion process, filtering the scheduler nodes that need to be optimized. Below is possible code for reference.
self.create_foreach_nodes()
self.topological_sort_schedule()
self.logged_slow_fusion = set()
self.fuse_nodes()
self.sparse_aware_nodes() # sparse specific pass
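As a rough illustration of what such a pass could do, here is a standalone Python sketch. Note that SchedulerNode, origin_ops, and sparse_aware_nodes as written here are simplified stand-ins, not the real Inductor scheduler APIs: after fusion, the pass walks the scheduled nodes and collects those lowered from scatter-style ops so a later rewrite can replace them with CSR-specific nodes.

```python
# Hypothetical sketch of a sparse-aware scheduler pass; names and fields
# are assumptions for illustration, not real Inductor internals.

SCATTER_OPS = {"scatter_add", "scatter_reduce", "index_put_accumulate"}

class SchedulerNode:
    """Stand-in for an Inductor scheduler node."""
    def __init__(self, name, origin_ops):
        self.name = name
        self.origin_ops = origin_ops  # op names this node was lowered from

def sparse_aware_nodes(nodes):
    """Filter scheduler nodes whose lowering devolves to atomic scatter IR."""
    return [n for n in nodes if SCATTER_OPS & set(n.origin_ops)]

nodes = [
    SchedulerNode("buf0", ["index", "new_zeros"]),
    SchedulerNode("buf1", ["index", "new_zeros", "scatter_add"]),
    SchedulerNode("buf2", ["mm"]),
]
matched = sparse_aware_nodes(nodes)
assert [n.name for n in matched] == ["buf1"]
```

In the real scheduler the filter would inspect the fused nodes' IR rather than a flat op list, but the shape of the pass, select then rewrite, would be the same.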
- Preprocessing for Sparse Formats:
Recognizing the unique requirements of GNNs, we will implement a detection mechanism for input data indicative of sparsity. Such data will be preprocessed and converted into Compressed Sparse Row (CSR) format, making it more amenable to optimization for sparse operations.
- IR Transformation:
We propose to transform the conventional schedule nodes into new, sparse-specific schedule nodes within the PyTorch Inductor's IR. This will facilitate the generation of more efficient execution plans for sparse data. The code below shows a possible IR using Inductor's primitives:
buf1.group.device = cpu
buf1.group.iteration = ((nodes, feats), ())
buf1.sizes = ([nodes, feats], [])
buf1.mutations = ['buf0']
class buf1_loop_body:
    var_ranges = {z0: nodes, z1: feats}
    index0 = z0
    index1 = z0 + 1
    index2 = feats*indirect1 + z1
    index3 = feats*z0 + z1
    def body(self, ops):
        start_index = self.get_index('index0')
        start = ops.load('csr_ptr', start_index)
        end_index = self.get_index('index1')
        end = ops.load('csr_ptr', end_index)
        out_idx = self.get_index('index3')
        acc = ops.load('C', out_idx)  # accumulator initialized from C
        for i in range(start, end):
            col = ops.load('csr_ind', i)
            set_indirect1 = self.set_indirect1(col)
            B_index = self.get_index('index2')
            B_val = ops.load('B', B_index)
            acc = ops.add(acc, B_val)
        store = ops.store('buf1', out_idx, acc, None)
        return store
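To make the intended semantics of this loop body explicit, here is a plain-Python rendering (buffer names follow the IR sketch; csr_ptr/csr_ind are the CSR row pointers and column indices, B the feature matrix being gathered, C the initial accumulator):

```python
import numpy as np

def csr_row_sum(csr_ptr, csr_ind, B, C):
    """Plain-Python equivalent of the buf1_loop_body IR sketch."""
    nodes, feats = C.shape
    buf1 = C.copy()
    for z0 in range(nodes):              # outer iteration over output rows
        start = csr_ptr[z0]              # ops.load('csr_ptr', index0)
        end = csr_ptr[z0 + 1]            # ops.load('csr_ptr', index1)
        for z1 in range(feats):
            acc = C[z0, z1]              # ops.load('C', index3)
            for i in range(start, end):
                col = csr_ind[i]         # ops.load('csr_ind', i)
                acc += B[col, z1]        # ops.load('B', index2); ops.add
            buf1[z0, z1] = acc           # ops.store('buf1', index3, acc)
    return buf1

# Tiny example: 2 output rows, row 0 gathers columns {1, 2}, row 1 gathers {0}.
csr_ptr = np.array([0, 2, 3])
csr_ind = np.array([1, 2, 0])
B = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
C = np.zeros((2, 2))
out = csr_row_sum(csr_ptr, csr_ind, B, C)
assert out.tolist() == [[8.0, 10.0], [1.0, 2.0]]
```

Because each (z0, z1) output element is written by exactly one iteration, the outer loops can be parallelized without the atomic adds that the current scatter-based lowering requires.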
- Sparse-Aware Code Generation:
Leveraging the modified IR, we will develop a code generation engine capable of producing highly optimized code for executing sparse operations in CSR format on CPU backends. The ideal final code can be checked here.
cc @mingfeima @yanbing-j @xinchen9 @zhang677