RFC: Polyhedral Optimization Pass for PyTorch Inductor
This RFC proposes adding an optional polyhedral optimization pass to PyTorch Inductor to enable fusion of operation sequences that the current fusion heuristics cannot handle. The initial test case demonstrates a ~1.4x speedup on the RMSNorm + chunking + gating pattern (found in SwiGLU) used in architectures like Llama.
Motivation
PyTorch Inductor currently relies on pattern matching and heuristic-based fusion strategies. While effective for many cases, these approaches fail to fuse certain operation sequences where fusion is mathematically valid and beneficial. For a case study demonstrating the speedup on SwiGLU, see "[Example] Benefits of Polyhedral Optimization" (morrison-turnansky/pytorch, Pull Request #3) on GitHub.
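For concreteness, the target pattern can be sketched in plain PyTorch. The function name and shapes below are illustrative only; the actual benchmark lives in the linked PR.

```python
import torch
import torch.nn.functional as F

def rmsnorm_swiglu(x, weight, proj, eps=1e-6):
    """Illustrative sketch of the RMSNorm + chunking + gating pattern."""
    # RMSNorm: scale by the reciprocal root-mean-square over the last dim.
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
    # Project, then chunk the result into gate and value halves (SwiGLU).
    gate, value = proj(x).chunk(2, dim=-1)
    # Gating: SiLU(gate) elementwise-multiplied with the value half.
    return F.silu(gate) * value
```

Current fusion heuristics break this sequence into multiple kernels; the proposed pass aims to emit it as one.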
High-Level Design
Add a new opt-in compilation pass that users can enable via compilation flags. The pass would target Inductor's loop-level IR and apply a subset of polyhedral optimization specifically tailored to tensor workflows.
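Intended usage might look like the sketch below. The flag name is hypothetical; it does not exist today and would be settled during review.

```python
import torch

# Hypothetical opt-in flag (illustrative only; not an existing config option):
# torch._inductor.config.polyhedral_fusion = True

@torch.compile
def gated(x):
    # Stand-in workload; the pass would apply to patterns like SwiGLU.
    gate, value = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * value
```

With the flag left at its default (off), compilation behaves exactly as it does today.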
Scope Constraints (Initial Release)
- Inference Only: Training support is deferred to avoid the complexity of the backward pass.
  - These techniques can ultimately be extended to support both training and inference, but we want to defer training-specific complexities such as synchronization issues.
- Subset of Polyhedral: Use a simplified form of polyhedral analysis focused on identifying fusion opportunities, thereby minimizing compilation overhead.
  - We will focus on identifying fusion opportunities via loop-level dependency analysis. Initially, we will follow a general heuristic that fusion will be profitable (see Testing Strategy).
  - We can later expand this to use the cost functions standard in polyhedral analysis to rigorously determine whether fusion will be profitable.
  - Refer to the references for an example of lightweight polyhedral analysis.
- Opt-In: Disabled by default to ensure zero impact on existing workflows.
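As a toy model of the loop-level dependency analysis described above (not Inductor's actual IR or API), a fusion-legality check can be sketched as: two loop nests of the same rank may fuse when every read of the producer's output occurs at dependence distance zero, i.e. there is no loop-carried dependence.

```python
def can_fuse(producer_rank, consumer_rank, read_offsets):
    """Toy fusion-legality check over per-dimension access offsets.

    read_offsets: for each read of the producer's output in the consumer,
    a tuple of index offsets per loop dimension (e.g. (0, 1) means the
    consumer at (i, j) reads the producer's output at (i, j + 1)).
    """
    # Iteration spaces must have the same rank to align the loop nests.
    if producer_rank != consumer_rank:
        return False
    # All reads at distance zero => no loop-carried dependence => fusible.
    zero = (0,) * consumer_rank
    return all(offset == zero for offset in read_offsets)

# An element-wise consumer reading the producer at the same index is fusible;
# a stencil-like shifted read introduces a loop-carried dependence and is not.
elementwise_ok = can_fuse(2, 2, [(0, 0)])
stencil_blocked = can_fuse(2, 2, [(0, 1)])
```

The real pass would operate on Inductor's loop-level IR rather than explicit offset tuples, but the legality question it answers is the same.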
Dynamic vs Static Shapes
A primary benefit of polyhedral analysis is that it handles dynamic shapes by design: loop-level dependency analysis depends on the rank of a tensor, not its specific shape. The end goal is full dynamic-shape support.
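Because the analysis is rank-based, the same fused kernel should serve any concrete size along a dynamic dimension. A sketch using PyTorch's existing dynamic-shape support:

```python
import torch

# dynamic=True asks the compiler to generalize over symbolic sizes; the
# rank-based dependency analysis is unaffected by the concrete shape.
@torch.compile(dynamic=True)
def gated_dynamic(x):
    gate, value = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * value
```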
Breaking Changes
None. This is an additive feature:
- Default behavior is unchanged
- Requires explicit opt-in via a compilation flag
- Falls back to standard Inductor fusion when polyhedral analysis is unavailable or unprofitable
Testing Strategy
There are two areas to test at the function level:
- Fusion: does our pass result in the generated kernel being fused?
  - This can be verified by a test similar to test_fusion_codegen in the example, which follows the pattern of the tests in pytorch/test/inductor/test_loop_ordering.py.
- Numerical Accuracy
  - This should be treated like the existing eager/Inductor comparison tests, with eager mode as the source of truth.

Beyond function-level tests:
- End-to-End Performance Analysis
  - Verify that a set of desired/common models actually sees a performance increase, and monitor for performance regressions.

We propose that for each graph where we expect the polyhedral pass to make substantial changes (e.g., SwiGLU, among others), we add a test verifying that fusion occurred and a numerical-accuracy test.
References
- Simplified polyhedral form, a precedent for applying a subset of polyhedral optimizations to tensor workflows: https://mlir.llvm.org/docs/Rationale/RationaleSimplifiedPolyhedralForm/
- High-level background on polyhedral optimization and a discussion of its wide usage in compilers: http://polyhedral.info/
- A detailed discussion of polyhedral compilation: https://pliss2019.github.io/albert_cohen_slides.pdf