Authors: @jgong5, @leslie-fang-intel, @chunyuan-w
Contributors: @jgong5, @leslie-fang-intel, @chunyuan-w, @sanchitintel, @frost-intel
TL;DR: We are excited to share the ongoing development of the “max-autotune” mode for the Inductor CPU backend in torch.compile (see the RFC here). This feature profiles multiple implementations of operations at compile time and selects the best-performing one, trading longer compilation times for improved runtime performance. This enhancement is particularly beneficial for GEMM-related operations.
In the Inductor CPU backend, we’ve introduced a C++ template-based GEMM implementation as an alternative to the ATen-based approach, which relies on the oneDNN and MKL libraries. This is similar to the max-autotune mode on CUDA, where implementations from ATen, Triton, and CUTLASS are considered.
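Conceptually, the profile-and-select step amounts to timing each candidate implementation on representative inputs and keeping the fastest. The C++ below is a minimal sketch of that idea only: Inductor's actual autotuning logic lives in its Python compiler, and the `Candidate` struct and `pick_fastest` function are hypothetical names introduced here for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical description of one candidate implementation of the same GEMM
// (e.g. an ATen/oneDNN call or a generated C++ template kernel).
struct Candidate {
  std::string name;
  std::function<void()> run;  // executes the kernel once on fixed inputs
};

// Warm up once, then return the median of `reps` timed runs in nanoseconds.
inline double median_time_ns(const Candidate& c, int reps) {
  c.run();  // warm-up: caches, page faults, lazy initialization
  std::vector<double> samples;
  samples.reserve(reps);
  for (int r = 0; r < reps; ++r) {
    auto t0 = std::chrono::steady_clock::now();
    c.run();
    auto t1 = std::chrono::steady_clock::now();
    samples.push_back(std::chrono::duration<double, std::nano>(t1 - t0).count());
  }
  std::nth_element(samples.begin(), samples.begin() + reps / 2, samples.end());
  return samples[reps / 2];
}

// Profile every candidate and return the index of the fastest one.
inline std::size_t pick_fastest(const std::vector<Candidate>& cands, int reps = 5) {
  assert(!cands.empty() && reps > 0);
  std::size_t best = 0;
  double best_ns = median_time_ns(cands[0], reps);
  for (std::size_t i = 1; i < cands.size(); ++i) {
    double t = median_time_ns(cands[i], reps);
    if (t < best_ns) {
      best_ns = t;
      best = i;
    }
  }
  return best;
}
```

Using the median rather than the mean makes the choice less sensitive to one-off outliers, such as a context switch during a timed run. The extra compile time mentioned above is exactly the cost of running these benchmarks.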
Design
GEMM (General Matrix Multiply) is arguably the most critical operation in AI workloads, as it often dominates execution time. However, optimizing GEMM for modern computing architectures is challenging due to several factors:
- Data Reuse: Efficient GEMM optimization hinges on maximizing data reuse while striving for high device FLOPS utilization by leveraging the latest hardware accelerators. The optimization process involves determining how to decompose work across multiple computing units, organizing data layouts, and scheduling data accesses so that data moves quickly, with coalesced accesses and good locality. This also includes effective use of registers and maximizing the utilization of hardware accelerators.
- Fusion with Other Operations: GEMM operations in AI workloads are often fused with other operations, particularly pointwise operations. These fusions can occur before the GEMM, such as dequantizing the weights of weight-only quantized GEMMs, or after it, such as applying various nonlinearities. More complex fusion patterns might involve multiple GEMMs, like those found in MLPs and Attention mechanisms. The code generation must be flexible enough to accommodate these diverse fusion patterns.
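To make the data-reuse argument concrete: an M×N×K GEMM performs 2·M·N·K floating-point operations but touches only M·K + K·N + M·N elements, so larger blocks yield more FLOPs per byte moved. The helper below is a small sketch computing this ratio for one block, assuming A and B are read once and C is read and written once; `arithmetic_intensity` is a name chosen for illustration.

```cpp
#include <cassert>
#include <cstddef>

// FLOPs per byte for one M x N x K GEMM block, assuming A (M x K) and B (K x N)
// are loaded once, C (M x N) is loaded and stored once, and each element is
// `elem_bytes` wide (4 for FP32, 2 for BF16/FP16).
inline double arithmetic_intensity(std::size_t M, std::size_t N, std::size_t K,
                                   std::size_t elem_bytes) {
  assert(M > 0 && N > 0 && K > 0 && elem_bytes > 0);
  double flops = 2.0 * M * N * K;  // one multiply and one add per (m, n, k)
  double bytes = static_cast<double>(elem_bytes) * (M * K + K * N + 2 * M * N);
  return flops / bytes;
}
```

For square FP32 blocks of side n this works out to n/8 FLOPs per byte: 2 for n = 16, 8 for n = 64, and 32 for n = 256. This growth with block size is the motivation for keeping large blocks resident in cache and registers.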
Optimizing GEMM on CPUs presents its own set of challenges: CPUs have a limited number of registers and, unlike GPUs with their shared memory, no software-addressable fast on-chip memory. This requires careful register allocation, efficient data layout arrangement, and strategic scheduling of data accesses to fully utilize the hardware-managed cache hierarchy. Below, we outline how these challenges are being addressed in the C++ GEMM template.
Key Techniques in the C++ GEMM Template
To tackle the challenges mentioned, we have incorporated the following techniques into our C++ GEMM template:
- Multi-Level Blocking for Enhanced Data Locality:
  - Thread Blocking: Matrices are partitioned along the M, N, and K dimensions based on the number of threads, assigning each thread a tile of size Mt × Nt × Kt. The goal is to maximize thread occupancy with optimal per-thread data reuse by favoring square-shaped Mt × Nt tiles, while minimizing cross-thread synchronization overhead (splitting K across threads requires a reduction of partial results).
  - Cache Blocking: The “Mt, Nt, Kt” block is further partitioned according to CPU cache sizes into “Mc, Nc, Kc” to ensure good cache locality.
  - Register Blocking: The “Mc, Nc, Kc” block is further divided, based on the number and width of CPU registers, into “Mr, Nr, Kr” to maximize register file reuse and leverage SIMD hardware accelerators effectively.
- Pre-Packing Weights for Efficient Cache Usage: In inference scenarios, which are the primary focus of CPU AI workloads, model weights are constant. We pre-pack these weights during compilation, ensuring that data accesses within cache blocks are contiguous.
- Architecture-Specific Instruction Selection and Register Allocation: At the innermost level of computation, the choice of specific instructions (e.g., AVX2, AVX512, AMX, NEON) and register allocation for matrices A, B, and C (i.e., values for Mr, Nr, and Kr) is guided by the CPU architecture and GEMM sizes. Heuristics favor faster dot-product accelerators, efficient register usage, and optimal thread occupancy.
- Flexible Fusion of Epilogue Pointwise Operations: Arbitrary epilogue pointwise operations can be fused with GEMMs by reusing the existing C++ vectorized codegen and stitching the generated loop nest into the appropriate inner loop level of the C++ GEMM template, as sketched below:
```cpp
for (auto Mc : ...) {
  for (auto Nc : ...) {
    // Compute GEMM block Mc x Nc
    ...
    // Epilogue fusion generated by CPP vectorized codegen
    for (int i = 0; i < Mc; i++) {
      // main vec loop
      for (int j = 0; j < Nc/16*16; j+=16) {
        ...
      }
      // tail loop
      for (int j = Nc/16*16; j < Nc; j++) {
        ...
      }
    }
  }
}
```
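A runnable version of the epilogue-fusion loop nest, under simplifying assumptions: the GEMM block is computed naively, the fused epilogue is a hypothetical bias-add followed by ReLU, and the 16-wide "vector" loop is written as a plain inner loop rather than with SIMD intrinsics. What it illustrates is placement: the epilogue runs on each Mc × Nc block right after that block is computed, while its data is still in cache. `gemm_bias_relu` and `kVec` are illustrative names.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

constexpr int kVec = 16;  // assumed vector width, e.g. 16 FP32 lanes with AVX-512

// out = relu(A (M x K) * B (K x N) + bias (N)), row-major, computed block by
// block with the epilogue fused right after each Mc x Nc GEMM block.
void gemm_bias_relu(const float* A, const float* B, const float* bias, float* out,
                    int M, int N, int K, int Mc, int Nc) {
  assert(M > 0 && N > 0 && K > 0 && Mc > 0 && Nc > 0);
  for (int m0 = 0; m0 < M; m0 += Mc) {
    for (int n0 = 0; n0 < N; n0 += Nc) {
      const int mc = std::min(Mc, M - m0), nc = std::min(Nc, N - n0);
      // Compute GEMM block mc x nc (naive here; a real kernel calls a micro GEMM).
      for (int i = 0; i < mc; ++i)
        for (int j = 0; j < nc; ++j) {
          float acc = 0.0f;
          for (int k = 0; k < K; ++k) acc += A[(m0 + i) * K + k] * B[k * N + n0 + j];
          out[(m0 + i) * N + n0 + j] = acc;
        }
      // Fused epilogue on the same block.
      for (int i = 0; i < mc; ++i) {
        float* row = out + (m0 + i) * N + n0;
        const float* b = bias + n0;
        const int vec_end = nc / kVec * kVec;
        // main "vector" loop: kVec elements per iteration
        for (int j = 0; j < vec_end; j += kVec)
          for (int l = 0; l < kVec; ++l) row[j + l] = std::max(row[j + l] + b[j + l], 0.0f);
        // tail loop: the remaining nc % kVec elements
        for (int j = vec_end; j < nc; ++j) row[j] = std::max(row[j] + b[j], 0.0f);
      }
    }
  }
}
```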
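The three blocking levels from the first technique can be sketched as a plain FP32 loop nest. This is a simplified, single-threaded illustration, not the generated template: the tile sizes are hypothetical constants (the template derives them from thread count, cache sizes, and ISA), the thread-level loop runs serially where a parallel version would hand each (mt, nt) tile to a thread, and K is not split across threads.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical tile sizes, chosen only for illustration.
constexpr int Mt = 64, Nt = 64;           // thread blocking: tile owned by one thread
constexpr int Mc = 32, Nc = 32, Kc = 64;  // cache blocking: sub-tile sized for cache
constexpr int Mr = 4, Nr = 8;             // register blocking: accumulators in registers
static_assert(Mt % Mc == 0 && Nt % Nc == 0, "cache tiles must divide thread tiles");
static_assert(Mc % Mr == 0 && Nc % Nr == 0, "register tiles must divide cache tiles");

// Register-level kernel for an mr x nr x kc block starting at (m0, n0, k0).
// The mr x nr accumulators form a small local array the compiler can keep in registers.
static void micro_gemm(const float* A, const float* B, float* C, int N, int K,
                       int m0, int n0, int k0, int mr, int nr, int kc) {
  float acc[Mr][Nr] = {};
  for (int k = k0; k < k0 + kc; ++k)
    for (int i = 0; i < mr; ++i)
      for (int j = 0; j < nr; ++j)
        acc[i][j] += A[(m0 + i) * K + k] * B[k * N + n0 + j];
  for (int i = 0; i < mr; ++i)
    for (int j = 0; j < nr; ++j)
      C[(m0 + i) * N + n0 + j] += acc[i][j];
}

// C (M x N) = A (M x K) * B (K x N), all row-major; C is overwritten.
void blocked_gemm(const float* A, const float* B, float* C, int M, int N, int K) {
  assert(M > 0 && N > 0 && K > 0);
  std::fill(C, C + static_cast<long>(M) * N, 0.0f);
  // Thread blocking: in a parallel version each (mt, nt) tile goes to one thread.
  for (int mt = 0; mt < M; mt += Mt)
    for (int nt = 0; nt < N; nt += Nt)
      // Cache blocking: a Kc-deep slice, split into Mc x Nc sub-tiles.
      for (int kk = 0; kk < K; kk += Kc)
        for (int mc = mt; mc < std::min(mt + Mt, M); mc += Mc)
          for (int nc = nt; nc < std::min(nt + Nt, N); nc += Nc)
            // Register blocking: Mr x Nr micro tiles inside each cache sub-tile.
            for (int mi = mc; mi < std::min(mc + Mc, M); mi += Mr)
              for (int nj = nc; nj < std::min(nc + Nc, N); nj += Nr)
                micro_gemm(A, B, C, N, K, mi, nj, kk, std::min(Mr, M - mi),
                           std::min(Nr, N - nj), std::min(Kc, K - kk));
}
```

Because each tile size divides the one above it, every (m, n, k) triple is visited exactly once, and edge tiles are handled by clamping the micro-tile sizes.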
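Weight pre-packing can be sketched as a one-time layout transform. Assuming a row-major K×N weight B and a hypothetical register-block width `Nr`, the sketch below stores each Nr-wide column panel contiguously, so that for every k the micro kernel reads Nr consecutive values. The template's actual packed layout may differ, for example for low-precision data types; `pack_b` and `packed_at` are illustrative names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Pack row-major B (K x N) into ceil(N / Nr) column panels. Panel p holds columns
// [p*Nr, p*Nr + Nr) stored k-major: packed[(p*K + k)*Nr + j] = B[k][p*Nr + j].
// Columns past N in the last panel are zero-padded so every panel has width Nr.
std::vector<float> pack_b(const std::vector<float>& B, int K, int N, int Nr) {
  assert(K > 0 && N > 0 && Nr > 0);
  assert(B.size() == static_cast<std::size_t>(K) * N);
  const int panels = (N + Nr - 1) / Nr;
  std::vector<float> packed(static_cast<std::size_t>(panels) * K * Nr, 0.0f);
  for (int p = 0; p < panels; ++p)
    for (int k = 0; k < K; ++k)
      for (int j = 0; j < Nr; ++j) {
        const int n = p * Nr + j;
        if (n < N)
          packed[(static_cast<std::size_t>(p) * K + k) * Nr + j] =
              B[static_cast<std::size_t>(k) * N + n];
      }
  return packed;
}

// Read B[k][n] back from the packed buffer (handy for checking the layout).
float packed_at(const std::vector<float>& packed, int K, int Nr, int k, int n) {
  const int p = n / Nr, j = n % Nr;
  return packed[(static_cast<std::size_t>(p) * K + k) * Nr + j];
}
```

Since the weights are constant at inference time, this transform runs once during compilation instead of on every call.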
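The register-allocation side of instruction selection can be illustrated with a toy FP32 heuristic. The register counts below are architectural facts (AVX2: 16 × 256-bit, AVX-512: 32 × 512-bit, AArch64 NEON: 32 × 128-bit); everything else is an assumption for illustration: the `Isa` enum, the budget of Mr·nv accumulators plus nv B vectors plus one A broadcast register, the "FMAs per loaded register" score, and the cap of four vectors per row. AMX, which uses tile registers, is not modeled, and the template's real heuristics also weigh thread occupancy.

```cpp
#include <algorithm>
#include <cassert>

enum class Isa { AVX2, AVX512, NEON };

struct RegTile {
  int Mr;  // rows of C held in registers
  int Nr;  // columns of C held in registers (a multiple of the vector width)
};

inline int num_vec_regs(Isa isa) { return isa == Isa::AVX2 ? 16 : 32; }

inline int fp32_lanes(Isa isa) {
  switch (isa) {
    case Isa::AVX2: return 8;
    case Isa::AVX512: return 16;
    case Isa::NEON: return 4;
  }
  return 1;
}

// Hypothetical heuristic: Nr = nv vectors. Require Mr*nv accumulators + nv B
// vectors + 1 A-broadcast register to fit, and maximize FMAs per k-step divided
// by registers loaded per k-step, i.e. (Mr*nv) / (Mr + nv). N limits nv so the
// tile is not wider than the (rounded-up) problem.
inline RegTile choose_fp32_tile(Isa isa, int N, int max_nv = 4) {
  assert(N > 0 && max_nv > 0);
  const int regs = num_vec_regs(isa), lanes = fp32_lanes(isa);
  const int nv_cap = std::min(max_nv, (N + lanes - 1) / lanes);
  RegTile best{1, lanes};
  double best_score = 0.0;
  for (int nv = 1; nv <= nv_cap; ++nv) {
    const int mr = (regs - nv - 1) / nv;  // rows of accumulators that still fit
    if (mr < 1) break;
    const double score = static_cast<double>(mr * nv) / (mr + nv);
    if (score > best_score) {
      best_score = score;
      best = {mr, nv * lanes};
    }
  }
  return best;
}
```

With this toy rule a large N gives a 6 × 64 tile on AVX-512 (24 accumulators + 4 + 1 = 29 of 32 registers) and a 4 × 24 tile on AVX2, while a narrow N = 16 on AVX-512 collapses to a tall 30 × 16 tile, showing how GEMM sizes feed into the choice.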
Template Design for Reusability and Extensibility
The template is designed with two levels of abstraction: the “C++ GEMM template” and the “C++ micro GEMM.” The “C++ GEMM template” handles thread blocking, cache blocking, pre-packing weights, and generating outer loop levels for epilogue fusions. It then calls the “C++ micro GEMM,” which manages register blocking, instruction selection, and CPU-specific optimizations. The “C++ micro GEMM” is architecture-specific and can be extended with various register allocation and instruction selection algorithms. The “C++ GEMM template” is designed to be shared across multiple CPU architectures via the micro GEMM abstraction. The CPU architecture module provides details about CPU capabilities, such as instruction sets and multi-level cache sizes, which are utilized by both abstractions.
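The split between the two abstractions can be sketched with a compile-time interface: an architecture-neutral driver templated on a micro-kernel type that supplies its own register tile and innermost loop. `ScalarMicroGemm`, `gemm_template`, and the `kMr`/`kNr`/`compute` members are hypothetical names for illustration, not identifiers from the Inductor code base.

```cpp
#include <algorithm>
#include <cassert>

// Portable micro GEMM: C[mr x nr] += A[mr x kc] * B[kc x nr], with leading dimensions.
struct ScalarMicroGemm {
  static constexpr int kMr = 4, kNr = 8;  // register tile this kernel is written for
  static void compute(const float* A, const float* B, float* C, int mr, int nr,
                      int kc, int lda, int ldb, int ldc) {
    for (int i = 0; i < mr; ++i)
      for (int k = 0; k < kc; ++k)
        for (int j = 0; j < nr; ++j)
          C[i * ldc + j] += A[i * lda + k] * B[k * ldb + j];
  }
};

// Architecture-neutral driver: owns the K blocking and the tiling of C and
// delegates each innermost block to Micro. C (M x N) = A (M x K) * B (K x N).
template <typename Micro>
void gemm_template(const float* A, const float* B, float* C, int M, int N, int K,
                   int Kc = 64) {
  assert(M > 0 && N > 0 && K > 0 && Kc > 0);
  std::fill(C, C + M * N, 0.0f);
  for (int k0 = 0; k0 < K; k0 += Kc) {
    const int kc = std::min(Kc, K - k0);
    for (int m0 = 0; m0 < M; m0 += Micro::kMr)
      for (int n0 = 0; n0 < N; n0 += Micro::kNr)
        Micro::compute(A + m0 * K + k0, B + k0 * N + n0, C + m0 * N + n0,
                       std::min(Micro::kMr, M - m0), std::min(Micro::kNr, N - n0), kc,
                       /*lda=*/K, /*ldb=*/N, /*ldc=*/N);
  }
}
```

Usage is `gemm_template<ScalarMicroGemm>(A, B, C, M, N, K);`. An architecture-specific kernel, for example one written with AVX-512 or NEON intrinsics, would plug in by providing the same three members, leaving the driver unchanged.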
Current Status
The RFC summarizes the current development status and future plans. We are currently focused on developing and fine-tuning operations involving a single GEMM. We cover the most popular data types, including FP32, BF16, FP16, and INT8, with epilogue fusions for x86 CPUs. The “M” dimension can be static or dynamic, while “N” and “K” are assumed to be static, which is the typical scenario in most workloads. Although development is still ongoing, we have already observed promising speedups over pure ATen-based GEMMs, as measured by three benchmark suites and LLM inference tests. We plan to present our work, including technical details and performance results, at the upcoming PyTorch conference.
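One way to see why static “N” and “K” fit well with compile-time weight packing while “M” may stay dynamic: in the hypothetical sketch below, N and K are template parameters, so the packed weight buffer and its layout are fixed when the kernel is instantiated, while M (for example, the number of tokens in an LLM inference batch) is an ordinary runtime argument. This illustrates the shape assumption only; it is not how Inductor represents symbolic shapes, and `StaticNKLinear` is an illustrative name.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Linear layer with compile-time N (output features) and K (input features).
// The weight is packed once at construction; M is only known per call.
template <int N, int K, int Nr = 8>
class StaticNKLinear {
  static_assert(N % Nr == 0, "sketch assumes N is a multiple of Nr");

 public:
  // w is row-major K x N; pack it into N / Nr contiguous Nr-wide panels.
  explicit StaticNKLinear(const float* w) {
    for (int p = 0; p < N / Nr; ++p)
      for (int k = 0; k < K; ++k)
        for (int j = 0; j < Nr; ++j)
          packed_[(p * K + k) * Nr + j] = w[k * N + p * Nr + j];
  }

  // y (M x N) = x (M x K) * W; M may differ on every call.
  void operator()(const float* x, float* y, int M) const {
    assert(M >= 0);
    for (int m = 0; m < M; ++m)
      for (int p = 0; p < N / Nr; ++p) {
        float acc[Nr] = {};
        for (int k = 0; k < K; ++k)
          for (int j = 0; j < Nr; ++j)
            acc[j] += x[m * K + k] * packed_[(p * K + k) * Nr + j];
        for (int j = 0; j < Nr; ++j) y[m * N + p * Nr + j] = acc[j];
      }
  }

 private:
  std::array<float, static_cast<std::size_t>(N) * K> packed_{};
};
```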
Next Steps
Future work will focus on the following areas:
- Expanding support for single GEMMs with weight-only quantization, batch matmul, unpacked weights, and dynamic shapes for “N” and “K.”
- Extending support to operations involving multiple GEMMs, such as MLPs and Attention mechanisms.
- Continuing performance tuning, including improvements in thread and cache blocking and in the micro GEMMs.
- While the current GEMM template already provides basic support for non-x86 CPUs via the ATen Vectorized abstraction, we encourage community contributions to enhance support and optimization for other CPU architectures.