TorchInductor Update 6: CPU backend performance update and new features in PyTorch 2.1

In our previous posts, TorchInductor Update 4 and TorchInductor Update 5, @jgong5 @EikanWang shared the progress and a technical deep dive on the optimization work for the Inductor C++/OpenMP backend. In this post, I will introduce the new features and optimizations in PyTorch 2.1, with a specific focus on the advancements made in the Inductor CPU backend. Our goal is to enhance out-of-the-box applicability and deliver better performance on CPU platforms:

  • New data type support – BF16 and INT8 are added to speed up low-precision inference.
  • FP32 dynamic shape input is supported to benefit more dynamic scenarios.
  • A C++ Wrapper is added to reduce the Python runtime overhead compared with the default Python wrapper.
  • A flash-attention-based scaled dot product attention (SDPA) kernel is now available on CPU and is enabled in the Inductor backend via automatic pattern matching.

New data type support

BFloat16 Inference Path

The bfloat16 inference path is now enabled for the Inductor CPU backend. It is based on automatic mixed precision to implement the data type conversion. We apply performance optimizations similar to those introduced in TorchInductor Update 5, which categorize operations into two types, Conv/GEMM and non-Conv/GEMM, and optimize them differently. For Conv/GEMM operations, we leverage the oneDNN performance library, while for non-Conv/GEMM element-wise and reduction operations, we utilize Inductor C++ codegen for optimization.

Additionally, we apply post-op fusion and weight prepacking using the oneDNN library for Conv/GEMM operations. For non-Conv/GEMM element-wise and reduction operations, we perform bfloat16 legalization, which loads BF16 tensors as FP32 for computation and converts the results back to BF16 afterwards. In this way, we are able to reuse the explicit vectorization support in C++ codegen as much as possible to achieve optimal performance.
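
As a minimal usage sketch (not taken from the benchmarks above), BF16 inference with the Inductor CPU backend can be driven by torch.compile together with CPU autocast; the toy model and input below are placeholders only:

import torch
import torch.nn as nn

# Placeholder model; any eager-mode model (e.g. a torchvision ResNet) works the same way.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10)).eval()
example_input = torch.randn(32, 128)

compiled_model = torch.compile(model)  # Inductor is the default backend

# CPU automatic mixed precision inserts the FP32 <-> BF16 conversions described above.
with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    output = compiled_model(example_input)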

This approach has been measured and proven effective on popular deep learning models across three test suites (torchbench, huggingface, timm_models).

Passrate (Single-Socket Multi-threads)

+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor | 91%, 63/69 | 98%, 45/46  |  92%, 56/61 |
+----------+------------+-------------+-------------+

Passrate (Single-Core Single-thread)

+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor | 93%, 64/69 | 96%, 44/46  |  89%, 54/61 |
+----------+------------+-------------+-------------+

Geometric mean speedup (Single-Socket Multi-threads)

+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor |   1.81x    |    1.25x    |    2.35x    |
+----------+------------+-------------+-------------+

Geometric mean speedup (Single-Core Single-thread)

+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor |   1.74x    |    1.28x    |    1.29x    |
+----------+------------+-------------+-------------+

The above data are measured on AWS m7i.16xlarge instance (SPR).

INT8 Inference with Post Training Static Quantization

PyTorch 2.1 introduces a new export quantization flow, expected to significantly increase model coverage, enhance programmability, and simplify the user experience. We have enabled the Inductor CPU as one of the backends for this new quantization flow.

As part of the quantization frontend, we’ve enabled the in-tree X86InductorQuantizer, designed to apply post-training static quantization recipes specifically tailored for the Inductor CPU backend. The optimization in Inductor is similar to that for the other data types: we match quantization patterns, apply weight prepacking, and use post-op fusion with the oneDNN library for Conv/GEMM operations. For non-Conv/GEMM element-wise and reduction operations, we achieve optimal performance by enabling explicit vectorization with the uint8 data type in our C++ codegen.
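
A minimal sketch of the export quantization flow with X86InductorQuantizer is shown below. It assumes the PyTorch 2.1-era entry points (capture_pre_autograd_graph, prepare_pt2e, convert_pt2e), whose exact import paths may differ in later releases, and uses a toy model plus random calibration data purely as placeholders:

import torch
import torch.nn as nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU()).eval()  # placeholder model
example_inputs = (torch.randn(1, 3, 224, 224),)

# Capture the graph and annotate it with the x86 Inductor quantization recipes.
exported_model = capture_pre_autograd_graph(model, example_inputs)
quantizer = xiq.X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepared_model = prepare_pt2e(exported_model, quantizer)

# Calibrate with representative data (random data here only for illustration).
with torch.no_grad():
    prepared_model(*example_inputs)

# Convert to a quantized model and lower it through Inductor.
converted_model = convert_pt2e(prepared_model)
optimized_model = torch.compile(converted_model)
with torch.no_grad():
    result = optimized_model(*example_inputs)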

All CNN models from the TorchBench test suite have been measured and show significant speedups when compared with the Inductor FP32 inference path.

+----------+------------------------+---------------------------------------+
| Compiler | Geometric Mean Speedup | Geometric Mean Relative Accuracy Loss |
+----------+------------------------+---------------------------------------+
| inductor |     3.25x, 12/12       |             0.44%, 12/12              |
+----------+------------------------+---------------------------------------+

The above data are measured on AWS c6i.16xlarge instance (ICX).

FP32 Dynamic Shape Inference Path

The fp32 dynamic shape inference path is now enabled for the Inductor CPU backend. We fixed several functional/accuracy issues (#105651, #105314, #103579, #103511, #103147, #102263, #101793, #100230) to improve the pass rate of the fp32 dynamic shape inference path across three test suites (torchbench, huggingface, timm_models).

Passrate (Single-Socket Multi-threads)

+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor | 77%, 60/78 | 100%, 46/46 |  97%, 59/61 |
+----------+------------+-------------+-------------+

Passrate (Single-Core Single-thread)

+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor | 92%, 72/78 | 100%, 46/46 |  97%, 59/61 |
+----------+------------+-------------+-------------+

We apply the same performance optimizations as the fp32 static shape inference path to the dynamic shape path. The difference from the static shape path is that only the symbolic shapes of the activations are known at compile time. Therefore, we make assumptions about the actual activation shape when performing the Conv/GEMM weight prepacking.
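
As a minimal sketch (not taken from the benchmarks), the dynamic shape path can be exercised by compiling with dynamic=True and feeding inputs of varying batch size; the toy model below is a placeholder:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10)).eval()

# dynamic=True asks Inductor to compile with symbolic shapes up front,
# so varying batch sizes reuse the same compiled graph instead of recompiling.
compiled_model = torch.compile(model, dynamic=True)

with torch.no_grad():
    for batch_size in (1, 8, 33):
        output = compiled_model(torch.randn(batch_size, 128))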

Geometric mean speedup (Single-Socket Multi-threads)

+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor |   1.35x    |    1.15x    |    1.79x    |
+----------+------------+-------------+-------------+

Geometric mean speedup (Single-Core Single-thread)

+----------+------------+-------------+-------------+
| Compiler | torchbench | huggingface | timm_models |
+----------+------------+-------------+-------------+
| inductor |   1.48x    |    1.15x    |    1.48x    |
+----------+------------+-------------+-------------+

The above data are measured on AWS c6i.16xlarge instance (ICX).

C++ Wrapper Path

We propose C++ Wrapper in TorchInductor as a new Prototype feature in PyTorch 2.1.

Python, as the primary interface of PyTorch, is easy to use and efficient for development and debugging. The Inductor’s default wrapper generates Python code to invoke generated kernels and external kernels. However, in deployments requiring high performance, Python, as an interpreted language, runs relatively slowly compared to compiled languages.

We implemented an Inductor C++ Wrapper by leveraging the PyTorch C++ APIs to generate pure C++ code that combines the generated and external kernels. This allows for the execution of each captured Dynamo graph in pure C++, thereby reducing the Python overhead within the graph.

To activate this prototype feature, users need to set the following config:

import torch._inductor.config as config
config.cpp_wrapper = True

This will speed up your models by reducing the Python overhead of the Inductor wrapper.
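
A minimal usage sketch (with a placeholder model, not taken from the original benchmarks) that shows the flag in context with torch.compile:

import torch
import torch.nn as nn
import torch._inductor.config as config

config.cpp_wrapper = True  # prototype: emit a C++ wrapper instead of the Python one

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU()).eval()  # placeholder model
compiled_model = torch.compile(model)

with torch.no_grad():
    output = compiled_model(torch.randn(16, 64))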

For light workloads, where the overhead of the Python Wrapper is more dominant, the C++ Wrapper demonstrates a higher performance boost ratio. We grouped the models in torchbench, huggingface and timm_models by the average inference time of one iteration and categorized them into small, medium and large categories. Below are the geomean speedups achieved by the C++ Wrapper in comparison to the default Python Wrapper.

  • FP32 static shape mode

    C++ Wrapper demonstrated up to 1.06x speedup versus the existing Python Wrapper in the single-socket multi-thread scenario and up to 1.13x speedup in the single-core single-thread case.

    Geometric mean speedup (Single-Socket Multi-threads)

    +----------+---------------------+-----------------------------+------------------+
    | Compiler | Small (t <= 0.04s)  | Medium (0.04s < t <= 1.5s)  | Large (t > 1.5s) |
    +----------+---------------------+-----------------------------+------------------+
    | inductor |       1.06x         |            1.01x            |      1.00x       |
    +----------+---------------------+-----------------------------+------------------+
    

    Geometric mean speedup (Single-Core Single-thread)

    +----------+---------------------+-----------------------------+------------------+
    | Compiler | Small (t <= 0.04s)  | Medium (0.04s < t <= 5.0s)  | Large (t > 5.0s) |
    +----------+---------------------+-----------------------------+------------------+
    | inductor |       1.13x         |            1.02x            |       1.01x      |
    +----------+---------------------+-----------------------------+------------------+
    
  • FP32 dynamic shape mode

    C++ Wrapper brings a similar performance boost with dynamic shapes: up to 1.05x speedup compared to the Python Wrapper in the single-socket multi-thread scenario and up to 1.14x speedup in the single-core single-thread case.

    Geometric mean speedup (Single-Socket Multi-threads)

    +----------+---------------------+-----------------------------+------------------+
    | Compiler | Small (t <= 0.04s)  | Medium (0.04s < t <= 1.5s)  | Large (t > 1.5s) |
    +----------+---------------------+-----------------------------+------------------+
    | inductor |       1.05x         |            1.01x            |      1.00x       |
    +----------+---------------------+-----------------------------+------------------+
    

    Geometric mean speedup (Single-Core Single-thread)

    +----------+---------------------+-----------------------------+------------------+
    | Compiler | Small (t <= 0.04s)  | Medium (0.04s < t <= 5.0s)  | Large (t > 5.0s) |
    +----------+---------------------+-----------------------------+------------------+
    | inductor |       1.14x         |            1.02x            |      1.01x       |
    +----------+---------------------+-----------------------------+------------------+
    
  • BF16 static shape mode

    C++ Wrapper achieved up to 1.09x speedup versus the existing Python Wrapper in the single-socket multi-thread scenario and up to 1.17x speedup in the single-core single-thread scenario.

    Geometric mean speedup (Single-Socket Multi-threads)

    +----------+---------------------+-----------------------------+------------------+
    | Compiler | Small (t <= 0.02s)  | Medium (0.02s < t <= 0.3s)  | Large (t > 0.3s) |
    +----------+---------------------+-----------------------------+------------------+
    | inductor |       1.09x         |            1.03x            |      1.04x       |
    +----------+---------------------+-----------------------------+------------------+
    

    Geometric mean speedup (Single-Core Single-thread)

    +----------+---------------------+-----------------------------+------------------+
    | Compiler | Small (t <= 0.02s)  | Medium (0.02s < t <= 1.5s)  | Large (t > 1.5s) |
    +----------+---------------------+-----------------------------+------------------+
    | inductor |       1.17x         |            1.04x            |      1.03x       |
    +----------+---------------------+-----------------------------+------------------+
    

The above data are measured on AWS c6i.16xlarge instance (ICX) for FP32 and AWS m7i.16xlarge instance (SPR) for BF16.

SDPA Optimization

The fused SDPA is now enabled for the CPU backend. It is one of the optimized SDPA algorithms designed for memory-bound problems, with better parallelism and memory access patterns. Since fused SDPA for the CUDA backend was already introduced in PyTorch 2.0, this feature closes the gap between CPU and CUDA.

The flash attention kernel, one fused SDPA algorithm, has been added for CPU. Both the forward and backward paths are implemented for the float32 and bfloat16 data types. With blocking, we fuse the GEMMs and the softmax for each block at once. The causal attention mask is also supported, with early termination. In addition, we added an SDPA selection function that automatically chooses one SDPA implementation among the available ones; in general, flash attention has a higher priority than the unfused SDPA. For SDPA-related models that do not call SDPA explicitly, the pattern matcher in Inductor can combine several small operators into one SDPA operator.
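
As an illustration (not from the original measurements), both routes can be exercised as sketched below: a direct call to scaled_dot_product_attention, and a hand-written attention that Inductor's pattern matcher can rewrite into the fused op; the tensor shapes are arbitrary placeholders:

import torch
import torch.nn.functional as F

# (batch, heads, sequence length, head dim); arbitrary placeholder shapes
q, k, v = (torch.randn(1, 16, 1024, 64) for _ in range(3))

# Direct call: dispatches to the fused CPU kernel when available,
# with early termination for the causal mask.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Hand-written attention: Inductor's pattern matcher can rewrite the
# matmul + softmax + matmul sequence into a single SDPA operator.
def attention(q, k, v):
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

compiled_attention = torch.compile(attention)
out_compiled = compiled_attention(q, k, v)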

We measured the SDPA-related models and observed consistent gains compared with the unfused SDPA.

Geometric mean speedup (Single-Socket Multi-threads)

+----------+------------------------+------------------------+
| Compiler | Geometric Speedup FP32 | Geometric Speedup BF16 |
+----------+------------------------+------------------------+
| inductor |      1.15x, 20/20      |      1.07x, 20/20      |
+----------+------------------------+------------------------+

Geometric mean speedup (Single-Core Single-thread)

+----------+------------------------+------------------------+
| Compiler | Geometric Speedup FP32 | Geometric Speedup BF16 |
+----------+------------------------+------------------------+
| inductor |      1.02x, 20/20      |      1.04x, 20/20      |
+----------+------------------------+------------------------+

The above data are measured on AWS c6i.16xlarge instance (ICX) for FP32 and AWS m7i.16xlarge instance (SPR) for BF16.

Summary

This blog post from the Intel PyTorch team provides an update on the new features and performance optimizations introduced in the Inductor C++/OpenMP backend with PyTorch 2.1:

  • BFloat16 automatic mixed precision inference and fp32 dynamic shape inference are enabled as Beta features, with proven accuracy/performance across three test suites (torchbench, huggingface, timm_models).
  • INT8 post-training static quantization, the C++ Wrapper, and the flash-attention-based scaled dot product attention kernel are enabled as prototype features.

As a next step, we plan to expand our optimizations to improve the model pass rate on the dynamic shape path, enhance the quantized inference path, add float16 data type support, and cover the CPU training path.

Many thanks to @jansel, @desertfire, @eellison, @jerryzh168 and @Chillee for their invaluable contributions and unwavering support during the development.
