TorchInductor Update 9: Harden Vectorization Support and Enhance Loop Optimizations in TorchInductor CPP Backend

Author: @jgong5, @leslie-fang-intel
Contributors: @jgong5, @leslie-fang-intel, @CaoE, @zhuhaozhe, @jiayisunx

TL;DR: Contributors from Intel have been enhancing vectorization support and improving loop optimizations in the TorchInductor CPP backend. Our goal for vectorization is to support all Inductor IR operations, which we’ve nearly achieved by covering the most common data types, operations, and indexing methods, and by falling back to scalarization for the rest. We’ve also developed loop optimizations like outer loop fusion and loop splitting to further boost performance.

Hardening Vectorization Support

It’s been a while since our last update on the vectorization status in the TorchInductor CPP backend (previous post). In this post, we want to share our recent progress in strengthening vectorization support and optimizing loop transformations.

When we began optimizing the TorchInductor CPP backend, we focused on the most common use cases from three benchmark suites (TorchBench, Hugging Face, and TIMM), following the Pareto principle. We used the CppVecKernelChecker to scan the Inductor loop IR and enable vectorization only where it was known to be safe. This approach was effective, yielding significant performance improvements with reasonable development effort. However, it also left some performance gains untapped and introduced maintenance challenges due to the limitations built into the CppVecKernelChecker.

Our goal is to make the CPP codegen support vectorization for all Inductor IR inputs and eliminate the need for CppVecKernelChecker. While vectorization can sometimes result in worse performance than its scalar counterpart, we employ heuristics (explained later) to choose the best configurations. The key takeaway is that we no longer fall back to scalar kernels due to functional gaps. Below, you’ll see diagrams comparing the current and expected workflows. With all the efforts introduced in this post, we are finally able to remove CppVecKernelChecker via the PR contributed by @zhuhaozhe (PR link).

Current Flow with CppVecKernelChecker:

[flow diagram]

Desired Flow without the need for CppVecKernelChecker:

[flow diagram]

Before diving into the limitations and how we’ve addressed them in the current CPP codegen, let’s start with a simple example to understand the basic vectorization strategy we initially implemented.

For the following PyTorch program “fn”:

import torch

@torch.compile
def fn(x, y):
    return x * x + y

x = torch.randn([10, 1024])[:, :1000]
y = torch.randn([10, 1024])[:, :1000]

fn(x, y)

We generate the following vectorized CPP kernel:

extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       float* out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(10L); x0+=static_cast<long>(1L))
        {
            for(long x1=static_cast<long>(0L); x1<static_cast<long>(992L); x1+=static_cast<long>(16L))
            {
                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0)), 16);
                auto tmp2 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1 + (1024L*x0)), 16);
                auto tmp1 = tmp0 * tmp0;
                auto tmp3 = tmp1 + tmp2;
                tmp3.store(out_ptr0 + static_cast<long>(x1 + (1000L*x0)));
            }
            #pragma omp simd simdlen(8) 
            for(long x1=static_cast<long>(992L); x1<static_cast<long>(1000L); x1+=static_cast<long>(1L))
            {
                auto tmp0 = in_ptr0[static_cast<long>(x1 + (1024L*x0))];
                auto tmp2 = in_ptr1[static_cast<long>(x1 + (1024L*x0))];
                auto tmp1 = decltype(tmp0)(tmp0 * tmp0);
                auto tmp3 = decltype(tmp1)(tmp1 + tmp2);
                out_ptr0[static_cast<long>(x1 + (1000L*x0))] = tmp3;
            }
        }
    }
}

There are two loop variables, corresponding to the two dimensions of the input tensors, and the generated code follows a load-compute-store pattern. Vectorization converts these load, compute, and store operations into their vectorized counterparts. Vectorized loads and stores work best when data accesses are contiguous, so to determine whether vectorization is possible, and at what loop level, we analyze the stride of each loop variable in the index expressions: a stride of 1 indicates contiguous access, a stride of 0 indicates scalar (broadcast) access, and any other value indicates non-contiguous access. In the example CPP code, the index expressions “x1 + 1024*x0” and “x1 + 1000*x0” show that accesses are contiguous along “x1”, so we vectorize the inner loop over “x1”.
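
To make this concrete, here is a minimal sketch of the stride check using sympy. The helper name is hypothetical and this is not the actual Inductor implementation, which performs a richer analysis on its loop IR, but it captures the rule described above:

import sympy

def classify_access(index_expr, loop_var):
    # Stride of the index expression with respect to the loop variable.
    stride = sympy.simplify(sympy.diff(index_expr, loop_var))
    if stride == 1:
        return "contiguous"       # vectorized load/store is possible
    if stride == 0:
        return "scalar"           # broadcast a single value
    return "non-contiguous"       # fall back to scalarization

x0, x1 = sympy.symbols("x0 x1", integer=True)
print(classify_access(x1 + 1024 * x0, x1))  # contiguous
print(classify_access(x1 + 1024 * x0, x0))  # non-contiguous (stride 1024)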

Now, let’s review the problems. Our last analysis identified ten categories of limitations that prevented vectorization, which we summarize below in five broad groups. For more details, refer to the previous report:

  1. Lack of vectorization for data types beyond float32, bfloat16, float16, uint8, and their conversions.
  2. Issues related to indexing (index expressions and indirect indexing).
  3. Non-contiguous load and store operations.
  4. Unsupported operations, reduction types, and store modes.
  5. Inability to vectorize outer loops.

In the following sections, we’ll explain how we’ve addressed these issues in the current CPP codegen and highlight additional optimizations we’ve implemented.

Vectorization for More Data Types

While most AI workloads primarily use data types like float32, bfloat16, float16, and uint8, other data types like int64, int32, float64, and bool are also common, particularly for indexing and masks.

The major challenge in vectorization is aligning the vectorization factor (VF, or “tiling factor” in our implementation) with the fixed bit-width of vector registers and varying bit-widths of data types. Each generated kernel must work with a specific VF, even when handling different data types simultaneously. For all data types within the same kernel, the following constraint must be met:

bit-width of the data type * vectorization factor <= bit-width of a vector register    (1)

Previously, we maximized VF for float32 in a single vector register, e.g., 8 for AVX2 and 16 for AVX512, and used only half or a quarter of a register for bfloat16/float16 and uint8, respectively. This approach becomes inefficient when supporting 64-bit data types like int64 or float64. To address this, @jgong5 introduced a new at::vec::VectorizedN abstraction that packs “N” vector registers (PR link), removing constraint (1) above, since a logical vector of VF lanes can now span multiple physical registers. For example, we can pack two int64 or double registers to match the VF of a float32 register. This design supports an arbitrary VF. For AVX512, using VF=32 for float16/bfloat16 is more efficient than VF=16, as implemented by @CaoE (PR link).
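
As a rough illustration of how packing registers removes constraint (1), the sketch below computes how many physical registers a logical vector needs for a given data type and VF. The helper name and the assumption of 512-bit AVX512 registers are ours for illustration, not Inductor code:

def registers_needed(dtype_bits: int, vf: int, register_bits: int = 512) -> int:
    # Number of vector registers packed together so that `vf` lanes of a
    # `dtype_bits`-wide type fit (ceiling division).
    return -(-dtype_bits * vf // register_bits)

# With VF=16 chosen for float32 on AVX512:
print(registers_needed(32, 16))   # 1 register for float32
print(registers_needed(64, 16))   # 2 packed registers for int64/double
print(registers_needed(16, 32))   # 1 register for bfloat16/float16 at VF=32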

Supporting Booleans required special attention. Booleans are used to model masks in ternary operations (e.g., ops.where) and in masked operations like conditional loading for padding or windowed operations. @jgong5 created a convenient at::vec::VecMask class to represent vectorized Booleans, enabling basic logical Boolean operations and masked vector loading of any data type.

Additionally, @jgong5 introduced a new template function at::vec::convert for type conversion between VectorizedNs of the same number of elements (PR link). This unifies the codegen for ops.to_dtype across various vectorized data types.
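
For example, a compiled program like the sketch below exercises both vectorized Boolean masks (the ops.where path that VecMask supports) and dtype conversion (the ops.to_dtype path that at::vec::convert unifies). The program and shapes are illustrative, and the exact kernels Inductor generates may differ across versions:

import torch

@torch.compile
def masked_cast(x, y):
    mask = x > 0                      # Boolean mask
    z = torch.where(mask, x, y)       # ternary op consuming the mask
    return z.to(torch.bfloat16)       # dtype conversion via ops.to_dtype

x = torch.randn(10, 1000)
y = torch.randn(10, 1000)
masked_cast(x, y)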

Enhanced Stride Analysis on Indexes

As shown in the earlier example, analyzing the stride of the index used in load/store operations is crucial for determining whether vectorization is possible. For simple affine index formulas, this analysis is straightforward. However, it becomes more complex when the expression is non-linear and contains FloorDiv or ModularIndexing, which often result from reshape operations. In some cases, we can still conclude that the stride is constant within each vectorized range. For example, the formula 128*((x2//256)) + ModularIndexing(x2, 1, 128) has a stride of 1 for x2 within any vectorized range x2 = 16*a + b (0 <= b < 16), so we can still vectorize over x2 even though the formula contains FloorDiv and ModularIndexing. In PR 117221, @jgong5 enhanced stride analysis for such cases, enabling vectorization in more scenarios.
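
A quick numeric check (an illustrative sketch, not part of the codegen) confirms the claim: within every 16-wide vectorized range the non-linear formula above advances with stride 1.

# Evaluate 128*(x2 // 256) + (x2 % 128) over 16-wide vector ranges starting
# at x2 = 16*a and check that consecutive elements differ by exactly 1.
# Note: ModularIndexing(x2, 1, 128) is (x2 // 1) % 128, i.e. x2 % 128.
def index(x2):
    return 128 * (x2 // 256) + x2 % 128

for a in range(64):
    vals = [index(16 * a + b) for b in range(16)]
    assert all(v2 - v1 == 1 for v1, v2 in zip(vals, vals[1:]))
print("stride is 1 within every 16-wide vector range")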

Scalarization as a Fallback

Previously, we disabled vectorization for a kernel when we couldn’t vectorize any operations within it. Now, we apply scalarization (i.e., looping over an array of scalars followed by loading into the vector register) for individual operations that can’t be directly vectorized, such as non-contiguous load/store or atomic-add. This approach allows us to still benefit from vectorizing the remaining operations.

We modify the earlier example slightly to make the load of “x” non-contiguous:

import torch

@torch.compile
def fn(x, y):
    return x * x + y

x = torch.randn([10, 2048])[:, :2000:2]
y = torch.randn([10, 1024])[:, :1000]

fn(x, y)

As shown in the generated code below, we scalarize the load of “x” (in_ptr0) while still vectorizing the rest of the kernel.

            for(long x1=static_cast<long>(0L); x1<static_cast<long>(992L); x1+=static_cast<long>(16L))
            {
                auto tmp0 =
                [&]
                {
                    __at_align__ std::array<float, 16> tmpbuf;
                    #pragma GCC unroll 16
                    for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                    {
                        tmpbuf[x1_inner] = in_ptr0[static_cast<long>((2L*x1) + (2L*x1_inner) + (2048L*x0))];
                    }
                    return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 16);
                }
                ()
                ;
                auto tmp2 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1 + (1024L*x0)), 16);
                auto tmp1 = tmp0 * tmp0;
                auto tmp3 = tmp1 + tmp2;
                tmp3.store(out_ptr0 + static_cast<long>(x1 + (1000L*x0)));
            }

Vectorization Related to Indexing (Index Expressions and Indirect Indexing)

Index expressions are a key component of the Inductor loop IR, used in many places to model index computation for tasks like concatenation, padding, slicing, pooling, and reshaping. Index expressions can also contain variables loaded from tensors, known as indirect indexing, which is used in operations like embedding lookups. Several improvements have been made to support index expressions and indirect indexing:

  1. Support for vectorization of integer types, addressed by “vectorizing more data types.”
  2. @jgong5 added arange to initialize the vector index for non-zero constant strides, based on enhanced stride analysis (PR link).
  3. In the same PR, index expressions with non-constant strides are scalarized.
  4. For indirect indexing, if the indirect index variables are independent of the vectorized loop variable, they can be ignored when checking contiguity. This is common in embedding lookups where the indirect indexing is based on individual embedding vectors that can be loaded into vectors. @jgong5 added CppCSEVariable to track dependencies of generated variables on loop variables, facilitating analysis of indirect indexing (PR link).
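
As a concrete case for item 4, an embedding lookup compiled as below involves an indirect index (the token id) that does not depend on the vectorized embedding-dimension loop variable, so each embedding row can still be loaded into vectors. The program and shapes are illustrative only:

import torch

@torch.compile
def lookup(weight, idx):
    # idx supplies the indirect index; the last (embedding) dimension
    # remains contiguous and can be vectorized.
    return torch.nn.functional.embedding(idx, weight)

weight = torch.randn(1000, 128)
idx = torch.randint(0, 1000, (8, 64))
lookup(weight, idx)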

Supporting Vectorization at Any Loop Level

Originally, we only supported vectorization at the innermost loop level. While this covered many common cases, there are scenarios where we want to vectorize at outer loop levels, such as in backward computation where sum reduction occurs along a different dimension from the contiguous one. The loop nest data structure was extended to support vectorization and tail splitting at any loop level. Reduction vectorization can also occur at the reduction dimension (horizontal reduction) or the parallel dimension (vertical reduction).
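
A typical case is summing over the leading dimension of a row-major tensor, as in the sketch below (shapes are arbitrary): the reduction runs along the non-contiguous dimension, so the profitable vectorization target is the outer, parallel loop, i.e., a vertical reduction.

import torch

@torch.compile
def col_sum(x):
    # The reduction is over dim 0 (stride 1024 in memory), while dim 1 is
    # contiguous, so vectorization happens over the parallel dimension.
    return x.sum(dim=0)

x = torch.randn(128, 1024)
col_sum(x)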

Vectorization Support for More Ops

Since the last update, several operators that previously lacked vectorization support have been addressed. Many of them are now vectorized directly by invoking the corresponding C++ vectorized functions; examples include the series of bitwise_X operators and the remainder operation, which gained vectorization through recent pull requests (e.g., bitwise_X, remainder).

However, some operators, like randn and atomic_add, cannot be vectorized directly. For these cases, we fall back to the scalarization technique described above, which still lets us benefit from vectorizing the remaining operations while handling the difficult ones separately. Support for more such operators was recently added by @zhuhaozhe in this PR.
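
For instance, a compiled program like the sketch below mixes an op that now has a direct vectorized implementation (torch.remainder) with one that is handled by scalarization (randn). The program and shapes are illustrative, and the exact generated kernels may vary by version:

import torch

@torch.compile
def fn(x, y):
    noise = torch.randn_like(x)            # randn has no vectorized counterpart; scalarized
    return torch.remainder(x, y) + noise   # remainder and add are vectorized directly

x = torch.randn(10, 1024)
y = torch.rand(10, 1024) + 0.5
fn(x, y)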

Other Optimizations: Vectorizing Tails, Optimized Welford Reduction, etc.

We have also focused on optimizing the handling of loop tails. Typically, loops are split into two parts: a main vectorized loop that handles the bulk of the data, and a tail loop that handles the remaining elements with scalar kernels. Although scalar kernels are generally adequate, vectorizing the tail loop can sometimes provide additional performance. For example, in a loop over 31 elements with VF=16, the main vectorized loop covers only the first 16 elements and the scalar tail must run 15 iterations, which is suboptimal. By using masked load and store operations and carefully managing reductions, we extended the vectorized code generation to handle tail loops more efficiently. PRs submitted by @jiayisunx have made this possible (PR1, PR2, PR3, PR4).

Another critical optimization is related to the Welford reduction, which is essential for computing means and variances in normalization operations. The core of this operation involves a costly division on Welford weights when calculating the mean incrementally. To optimize this, @CaoE introduced a change in this PR where the reciprocal of weights is cached, turning the division into a lighter multiplication. This has sped up all normalization operations that rely on Welford reduction.
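
The idea behind the optimization can be sketched in a few lines. This is a scalar Python illustration, not the generated C++: the per-step division in the incremental mean update is replaced by a multiplication with a precomputed reciprocal.

def welford(xs):
    # Incremental Welford update; the division by `count` on every step is
    # the expensive part in the vectorized kernel.
    mean, m2, count = 0.0, 0.0, 0
    # Cache reciprocals up front so the update uses a multiply instead.
    recip = [0.0] + [1.0 / n for n in range(1, len(xs) + 1)]
    for x in xs:
        count += 1
        delta = x - mean
        mean += delta * recip[count]      # was: delta / count
        m2 += delta * (x - mean)
    return mean, m2 / count if count else 0.0

print(welford([1.0, 2.0, 3.0, 4.0]))      # mean 2.5, variance 1.25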

Heuristics to Decide Vectorized Loop Level and Vectorization Factors

With the work mentioned above, we have achieved extensive vectorization of computations within compiled PyTorch programs. Even when vectorization at a particular loop level might result in non-contiguous data accesses or indexing, the code generation process can still fall back on scalarization, though this may lead to performance tradeoffs. The challenge now is determining the optimal loop level for vectorization and the corresponding vectorization factor (where VF=1 means using a scalar kernel). @leslie-fang-intel introduced a preliminary heuristics-based algorithm in these PRs (PR1, PR2) to make these decisions. The algorithm’s core idea is to evaluate the ratio of non-contiguity across all operations and determine if vectorization is worthwhile based on a predefined threshold. If vectorization is deemed appropriate, the vectorization factor is decided based on the data types and operation types. In the future, we are considering leveraging learning-based approaches to improve accuracy and generality.
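
A simplified version of the heuristic might look like the sketch below; the function name, threshold value, and input representation are illustrative and do not mirror the actual implementation.

def pick_vectorization_factor(ops, preferred_vf=16, threshold=0.5):
    """ops: list of (op_name, is_contiguous) pairs describing the loads and
    stores at the candidate loop level. Returns 1 (scalar kernel) or the VF."""
    if not ops:
        return preferred_vf
    non_contig = sum(1 for _, contiguous in ops if not contiguous)
    ratio = non_contig / len(ops)
    # Too many scalarized accesses: vectorization is unlikely to pay off.
    if ratio > threshold:
        return 1
    return preferred_vf

ops = [("load", True), ("load", False), ("store", True)]
print(pick_vectorization_factor(ops))   # 16: only 1/3 of accesses is non-contiguous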

Loop Optimizations

Outer Loop Fusion

There are several types of fusion that occur at different phases of the TorchInductor optimization process. Two typical examples are:

  • Pointwise fusion: This occurs during the lowering phase and fuses multiple pointwise operations with the same shapes.
  • Vertical fusion: During the scheduling phase, this type of fusion integrates the SchedulerNode of the consumer into its producer as a new FusedSchedulerNode when the reads of the consumer match the corresponding writes of the producer.

While vertical fusion is effective, it has limitations. It requires either identical dimension sizes for iteration and reduction or that only the consumer is a reduction node with the same dimension size as the producer. This limitation leaves some performance potential unexploited. For instance, consider the Softmax operation along the last dimension, which consists of several steps: max_reduction → sub_pointwise → exp_pointwise → sum_reduction → div_pointwise. With vertical fusion, three standalone loop regions would be generated:

  • max_reduction
  • sub_pointwise + exp_pointwise + sum_reduction
  • div_pointwise

Since these calculations are only performed along the innermost dimension, @leslie-fang-intel introduced the OuterLoopFusedSchedulerNode in this PR, which fuses the three standalone loops along the outer loop dimensions and performs the series of calculations at the innermost loop level. This approach improves performance through better data locality. Additionally, after fusing the outer loops, @leslie-fang-intel introduced LocalBuffer in this PR, which localizes the global buffer (originally used as a temporary storage) into a smaller local buffer, further enhancing data locality. These local buffers can even be reused across outer loop iterations.
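
The softmax pattern above can be reproduced with a small compiled program like the one below (shapes are arbitrary); with outer loop fusion, the three loop regions listed earlier share a single set of outer loops over the leading dimensions.

import torch

@torch.compile
def last_dim_softmax(x):
    # max_reduction -> sub -> exp -> sum_reduction -> div, all along dim=-1
    return torch.softmax(x, dim=-1)

x = torch.randn(32, 16, 1024)
last_dim_softmax(x)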

Loop Split

The “channels last” memory format typically delivers better performance, and we prefer it for CPP backend optimization, where the channel dimension becomes the vectorized dimension. However, challenges arise with operators that group along the channel dimension, such as Group Normalization. The input channels are divided into num_groups, and indexing the mean and variance values for normalization uses an expression like num_groups * x0 + (x2 // num_channels_per_group), where x2 is the channel dimension. This leads to non-contiguous loads of the mean and variance values.

To address this, we split the channel dimension into an outer dimension of size num_groups and an inner dimension of size num_channels_per_group. This allows us to load scalar mean and variance values along the inner dimension and broadcast them into the vector register. The current Inductor already optimizes loop order through reordering and simplification; on top of that, @jiayisunx introduced a pass that retraces the loop body with ad hoc loop order optimization to perform this loop split, as seen in this PR.
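
A representative program for this case is a channels-last GroupNorm, sketched below with arbitrary sizes; the exact loop split and generated kernels may vary by version.

import torch

@torch.compile
def group_norm(x, weight, bias):
    return torch.nn.functional.group_norm(x, num_groups=4, weight=weight, bias=bias)

x = torch.randn(8, 64, 56, 56).to(memory_format=torch.channels_last)
weight = torch.randn(64)
bias = torch.randn(64)
group_norm(x, weight, bias)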

Deterministic Parallel Reduction

Previously, we used OpenMP pragma to implement parallel reduction. Unfortunately, OpenMP parallel reduction is non-deterministic, which has caused multiple intermittent issues, as noted here and here. To address this problem, @zhuhaozhe added a two-pass deterministic parallel reduction in the CPP codegen. In this approach, the reduction work is first distributed among threads using #pragma omp for, with each thread reducing its assigned data to a local buffer. In the second pass, the main thread combines the reduction results from all threads. The final result remains stable as long as the number of threads remains the same.
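
Conceptually, the two-pass scheme works like the Python sketch below, with threads simulated sequentially (the real implementation distributes the first pass with #pragma omp for and uses per-thread buffers): because the partial results are always combined in thread order, the result is stable for a fixed thread count.

def deterministic_parallel_sum(data, num_threads=4):
    # Pass 1: each "thread" reduces its contiguous chunk into a local slot.
    chunk = (len(data) + num_threads - 1) // num_threads
    locals_ = [sum(data[t * chunk:(t + 1) * chunk]) for t in range(num_threads)]
    # Pass 2: the main thread combines the per-thread results in a fixed order.
    total = 0.0
    for partial in locals_:
        total += partial
    return total

print(deterministic_parallel_sum([0.1] * 1000))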

Summary

In this post, we discussed recent efforts to enhance vectorization support and implement various loop optimizations in the TorchInductor CPU backend. These optimizations have consistently delivered performance improvements. We are excited to share these advancements as part of our presentation on the TorchInductor CPU backend at the upcoming PyTorch conference.

We are continually working to push the boundaries of performance and efficiency in the TorchInductor CPU backend, and we look forward to further developments in this area. Stay tuned for more updates!
