TorchInductor Update 4: CPU backend started to show a promising performance boost

It’s Jiong Gong (@jgong5) from the Intel team working on PyTorch optimizations for CPU. In this post, I’d like to give an update on the recent progress of the CPU backend of TorchInductor, the new DL compiler of PyTorch. Designed to support multiple device backends, TorchInductor currently provides backends for both CPU and NVIDIA GPU. There has been great progress on GPU backend optimization for training workloads (see this for details). On the CPU side, since a significant portion of DL workloads running on CPU are inference, we started off by optimizing CPU inference. We began these efforts in early October from a low performance baseline (see Tables 1-3 below), and we are pleased to share the improvements we have made since then.
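
For readers who want to try it out, below is a minimal sketch (not our benchmark harness) of how a model can be run through the inductor CPU backend and compared against eager mode. The torchvision ResNet-50 model, the input shape and the single-iteration timing are illustrative choices only, and the frontend API may differ slightly across nightly builds.

```python
import time

import torch
import torch._dynamo
import torchvision.models as models

model = models.resnet50().eval()
x = torch.randn(32, 3, 224, 224)  # FP32, larger-batch inference input

# TorchDynamo captures the graph and hands it to the inductor backend,
# which generates C++/OpenMP kernels for CPU.
compiled_model = torch._dynamo.optimize("inductor")(model)

with torch.no_grad():
    model(x)            # eager warm-up
    compiled_model(x)   # warm-up run triggers compilation

    t0 = time.perf_counter()
    model(x)
    t_eager = time.perf_counter() - t0

    t0 = time.perf_counter()
    compiled_model(x)
    t_inductor = time.perf_counter() - t0

print(f"speedup over eager: {t_eager / t_inductor:.2f}x")
```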

Table 1: Pass Rate
+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor | 96%, 52/54 | 100%, 44/44 | 89%, 54/61  |
+----------+------------+-------------+-------------+
Table 2: Geomean Speedup over Eager Mode (Multi-threaded, 32 cores, large batch)
+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor |   1.04x    |    1.03x    |    1.06x    |
+----------+------------+-------------+-------------+
Table 3: Geomean Speedup over Eager Mode (Single-threaded, single batch)
+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor |   1.08x    |    1.02x    |    1.10x    |
+----------+------------+-------------+-------------+

What We Have Done

Our strategy for getting the TorchInductor CPU backend up to speed quickly is to apply know-how from existing work. We leveraged the optimization know-how from Intel® Extension for PyTorch (IPEX) for compute-bound Conv/GEMM ops, and from PyTorch ATen CPU kernels for memory-bound ops.

IPEX is one of the most efficient inference backends of TorchDynamo. It provides Conv/GEMM post-op fusions and weight prepacking based on the oneDNN and MKL performance libraries (refer to this page for more details about these IPEX optimizations). @XiaobingSuper has been contributing these optimizations to TorchInductor via the FX fusion path:

All of the Conv/GEMM post-op fusions have been landed and weight prepacking support is on the way.
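
As a concrete illustration of what these fusions target, here is a toy sketch (the module and shapes are made up for illustration): a Conv2d followed by a ReLU, which eager mode runs as two separate ops but which the FX-based fusion path can map to a single oneDNN convolution with ReLU applied as a post-op.

```python
import torch
import torch._dynamo
import torch.nn as nn

class ConvRelu(nn.Module):
    """A toy pattern of the kind the Conv/GEMM post-op fusion targets."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Eager mode runs conv and relu as two separate ops; the FX fusion
        # path can fold the ReLU into the oneDNN convolution as a post-op.
        return self.relu(self.conv(x))

model = ConvRelu().eval()
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    compiled = torch._dynamo.optimize("inductor")(model)
    out = compiled(x)
```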

ATen CPU kernels for memory-bound ops (e.g., pointwise and reduction) are optimized with OpenMP multi-threaded parallelism and explicit vectorization, and are highly performant on their own. The C++/OpenMP kernels generated by TorchInductor, on the other hand, were also parallelized with OpenMP but were auto-vectorized by the C++ compiler (via the “omp simd” pragma). This is already good enough for some simple kernels but leaves big gaps for more complicated ones, e.g., Softmax, which was up to 2.5x slower than eager mode. By studying these gaps, @jgong5 identified three missing performance features: explicit vectorization (the one with the major impact on performance), in-place buffers, and better buffer reuse heuristics (the latter two have relatively minor impact). @EikanWang contributed the explicit vectorization support (#87068, #87356, #88160, #88482, #88736, #89263, #89274). @jgong5 added the rest (#87037, TorchDynamo repo: #1468, #1486).
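
To make the memory-bound case concrete, here is a minimal sketch of a softmax written as the pointwise and reduction ops it decomposes into; when compiled, inductor fuses them into a single generated C++/OpenMP kernel, which is exactly where explicit vectorization matters. The function and shape below are illustrative choices only.

```python
import torch
import torch._dynamo

def softmax_fn(x):
    # Max-subtraction, exp, sum and division are all memory-bound ops that
    # the eager ATen kernels implement with OpenMP plus explicit vectorization.
    m = x.max(dim=-1, keepdim=True).values
    e = (x - m).exp()
    return e / e.sum(dim=-1, keepdim=True)

# Inductor fuses the decomposed ops into one generated C++/OpenMP kernel.
compiled_softmax = torch._dynamo.optimize("inductor")(softmax_fn)

x = torch.randn(128, 1024)
torch.testing.assert_close(compiled_softmax(x), softmax_fn(x))
```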

Thanks to the guidance and warm help from @jansel, @Chillee and @desertfire, the modular design of TorchInductor, and the efficiency of Python-based development, we were able to iterate fast and add most of the FP32 optimizations quickly. We are seeing promising performance results: the FP32 inference performance of TorchInductor has improved considerably on all three key DL benchmark suites (TorchBench, HuggingFace and TIMM). In particular, the inference performance on HuggingFace models is already better than what is achieved with Intel® Extension for PyTorch. See the detailed data below; we also include TorchDynamo+IPEX performance numbers as a reference.

Table 4: Pass Rate (New)
+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor | 96%, 52/54 | 100%, 44/44 | 95%, 58/61  |
+----------+------------+-------------+-------------+
Table 5: Geomean Speedup (New) over Eager Mode (Multi-threaded, 32 cores, large batch)
+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor |   1.16x    |    1.26x    |    1.08x    |
|   ipex   |   1.40x    |    1.14x    |    1.42x    |
+----------+------------+-------------+-------------+
Table 6: Geomean Speedup (New) over Eager Mode (Single-threaded, single batch)
+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor |   1.23x    |    1.23x    |    1.21x    |
|   ipex   |   1.35x    |    1.18x    |    1.54x    |
+----------+------------+-------------+-------------+

Next Steps

Our immediate next step is to make TorchInductor the best performing TorchDynamo backend for CPU FP32 inference. According to our model performance profiling, there are still some known performance enhancements to add:

  1. We are still a bit conservative in applying explicit vectorization. More code patterns can benefit from vectorization, e.g., masked loads in kernels like max pooling (#1914), indirect indexing into a vector in kernels like embedding lookup (#1851), and the mixed data type case that uses double-precision scalars (#1917); see the sketch after this list for ops that exercise these patterns.
  2. The generated transposition kernel is still slow (#1915).
  3. The LSTM kernel can be further optimized with the oneDNN library (#1918).
  4. Constant folding can save compute cost on constants in models like MobileBertForMaskedLM (#1860).
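
For concreteness, the sketch below shows PyTorch-level ops of the kind item 1 refers to; the shapes and sizes are arbitrary illustrative choices: a padded max pooling whose boundary handling needs masked loads, and an embedding lookup whose row gather needs indirect indexing.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 16, 56, 56)
# Max pooling: handling the padded border requires masked vector loads
# in the generated CPU kernel (#1914).
y = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

# Embedding lookup: gathering rows by index requires indirect indexing
# in the generated CPU kernel (#1851).
weight = torch.randn(1000, 128)
idx = torch.randint(0, 1000, (32,))
z = F.embedding(idx, weight)
```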

Besides the optimizations mentioned above, we are also adding features to reduce framework overhead, which is especially helpful for smaller models: @chunyuan-w is working on the CPP wrapper, which aims to reduce the Python wrapper overhead and can also help export models without Python dependencies. @Guobing-Chen is working on oneDNN kernel cache support to reduce integration overhead.

On the functionality side, we are striving to improve the model pass rate to 100% and to fix and re-enable the failing unit tests. @Valentine223 and @zhuzhaozhe are working on the following:

  1. [Inductor] incorrect result of vision_maskrcnn
  2. [Inductor] Support deterministic parallel reduction in CPP backend

In the long run, we plan to expand our optimizations to cover both training and inference, with various data types and with optimal codegen algorithms. Below is what we have in mind now, and we welcome requirements, ideas and suggestions from the community:

  1. Low-precision (BF16 and INT8) inference optimization
  2. Training optimization
  3. Loop tiling
  4. Autotune
  5. Further fusion optimization of Conv/GEMM kernels
  6. Explore alternative codegen paths: clang/llvm/triton
  • Test configuration for Tables 1-3: TorchDynamo commit 2e6737a46d6c4eea5b888345c1ffc7a282eb8153; HW: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
  • Test configuration for Tables 4-6: tested with PyTorch commit b843f4db0a26aae6536e6b971f73bcc5af21c90a + #89109, #89209; HW: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz