TorchInductor CPP Backend Vectorization Status Analysis

TL;DR: This post examines vectorization optimization in Inductor CPP backend for FP32 training and inference of 150 benchmark models. 90% of inference kernels and 71% of training kernels are vectorized. The remaining non-vectorized kernels are analyzed in 10 categories, highlighting the next steps to improve vectorization coverage: index-related operations, int64 support, vertical reduction, vectorization with fallback, and more.

In a recent post, @EikanWang shared the latest FP32 inference performance numbers of the TorchInductor CPU backend, and explained how optimizations were added, including vectorizing the Inductor CPP kernels. This technique greatly improved Inductor CPU’s performance. In this post, I will share statistics on the status of vectorization optimizations from real benchmark models, and discuss what is not yet vectorized, why, and possible approaches to address these issues.

Statistics of vectorized kernels

We assessed 150 models using the FP32 benchmarks from torchbench, huggingface, and timm. For inference, a total of 28,185 CPP kernels were generated, of which 25,579 (90%) were vectorized and the remaining 10% were scalar. For training, 103,084 kernels were generated, of which 73,909 (71%) were vectorized and 29% were not. Vectorization coverage for inference is therefore already high, while training kernels still need work, since we have only just started on training. In the following section, I analyze the non-vectorized kernels with specific examples to identify the most critical missing features.

Dive into non-vectorized kernels

The CppVecKernelChecker class and CppTile2DKernelChecker class in CPP codegen implement specific rules to determine the feasibility of vectorizing a kernel. A recent pull request includes debug logs that help identify why a kernel may fail to vectorize by providing insight into the conditions that were not met. The information has been grouped into 10 categories to help understand the reasons for vectorization failure. The two charts below illustrate the frequency of occurrence of each category for three different benchmarks, one for inference and the other for training.

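[Chart: frequency of each non-vectorized category per benchmark suite, inference]

[Chart: frequency of each non-vectorized category per benchmark suite, training]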

  1. index_expr

The main limitation for vectorization is the lack of support for index computation. At present, computation on indices is not vectorized, except when a scalar index is broadcast as a vector, which requires the index to be invariant with respect to the loop variable being vectorized. This check, however, prevents the vectorization of most index_expr cases.
Below is an example from XGLMForCausalLM:

        #pragma omp for 
        for(long i0=0; i0<1024; i0+=1)
        {
            #pragma GCC ivdep
            for(long i1=0; i1<1024; i1+=1)
            {
                auto tmp0 = static_cast<long>(i1);
                auto tmp1 = static_cast<long>(i0);
                auto tmp2 = static_cast<long>(1);
                auto tmp3 = tmp1 + tmp2;
                auto tmp4 = tmp0 < tmp3;
                auto tmp5 = static_cast<float>(0.0);
                auto tmp6 = -std::numeric_limits<float>::infinity();
                auto tmp7 = tmp4 ? tmp5 : tmp6;
                out_ptr3[i1 + (1024*i0)] = tmp7;
            }
        }

In this context, “i1” serves as both the inner-most loop variable and an index expression. To enable vectorization on “i1”, we can initialize “tmp0” with Vectorized::arange. Note that this also requires the ability to convert integer masks into floating-point masks, which is needed to build a valid “blendv” operation for the “where” that defines “tmp7”.
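The idea can be sketched in plain C++ without any vector library. This is only an illustration under stated assumptions: fixed-size lane arrays stand in for vector registers, the width of 8 lanes is assumed, and `kLanes`, `arange_lanes` and `causal_mask_lanes` are hypothetical names rather than anything the Inductor codegen emits.

```cpp
#include <array>
#include <cassert>
#include <limits>
#include <vector>

constexpr int kLanes = 8;  // assumed vector width (e.g. 8 floats per register)

// Hypothetical stand-in for a vector "arange": lanes hold base, base+1, ...
inline std::array<long, kLanes> arange_lanes(long base) {
    std::array<long, kLanes> v{};
    for (int l = 0; l < kLanes; ++l) v[l] = base + l;
    return v;
}

// Fills an n x n mask: 0 where column < row + 1, -inf elsewhere.
// n must be a multiple of kLanes to keep the sketch free of tail handling.
inline void causal_mask_lanes(float* out, long n) {
    assert(n % kLanes == 0);
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for (long i0 = 0; i0 < n; ++i0) {
        for (long i1 = 0; i1 < n; i1 += kLanes) {
            const auto idx = arange_lanes(i1);    // index_expr held as a vector
            std::array<bool, kLanes> int_mask{};  // result of the integer compare
            for (int l = 0; l < kLanes; ++l) int_mask[l] = idx[l] < i0 + 1;
            // The integer mask selects between two float values, lane by lane,
            // which is the role "blendv" plays in a real vector implementation.
            for (int l = 0; l < kLanes; ++l)
                out[i1 + l + n * i0] = int_mask[l] ? 0.0f : neg_inf;
        }
    }
}
```

Calling `causal_mask_lanes(buf.data(), 1024)` on a buffer of 1024*1024 floats would produce the same values as the scalar kernel above; the step from the integer compare to the float select is where the mask conversion is needed.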
There are more complicated cases, which occur less frequently than the previous one, such as the example from hf_BigBird below. Although the complex indices, which combine index_expr, computation and data loads, make vectorization challenging, there is still an advantage to vectorizing on i2, since the four stores are contiguous along that axis. However, we may need to implement a “vectorization with fallback” mechanism to incorporate both scalar and vectorized code into the same loop body. Pull request #97626 in pytorch/pytorch (“[Inductor] simplify CPP backend Tile2D code”, by jgong5) is part of this effort.

        #pragma omp for 
        for(long i0=0; i0<12; i0+=1)
        {
            #pragma GCC ivdep
            for(long i1=0; i1<192; i1+=1)
            {
                #pragma GCC ivdep
                for(long i2=0; i2<64; i2+=1)
                {
                    auto tmp0 = in_ptr1[(33*i0) + (33*((i2 + (64*i1)) / 135168)) + ((i2 + (64*i1)) / 4096)];
                    auto tmp9 = in_ptr1[30 + (33*i0) + (33*((122880 + i2 + (64*i1)) / 135168)) + ((i2 + (64*i1)) / 4096)];
                    auto tmp1 = static_cast<long>((33*i0) + (33*((i2 + (64*i1)) / 135168)) + ((i2 + (64*i1)) / 4096));
                    auto tmp2 = static_cast<long>(33);
                    auto tmp3 = ((tmp1 < 0) != (tmp2 < 0) ? (tmp1 % tmp2 != 0 ? tmp1 / tmp2 - 1 : tmp1 / tmp2) : tmp1 / tmp2);
                    auto tmp4 = static_cast<long>(13);
                    auto tmp5 = tmp3 * tmp4;
                    auto tmp6 = tmp0 + tmp5;
                    auto tmp7 = in_ptr0[i2 + (64*((tmp6 / 13) % 12)) + (768*(i1 % 64)) + (49152*(tmp6 % 13))];
                    auto tmp8 = in_ptr2[i2 + (64*((tmp6 / 13) % 12)) + (768*(i1 % 64)) + (49152*(tmp6 % 13))];
                    auto tmp10 = static_cast<long>(30 + (33*i0) + (33*((122880 + i2 + (64*i1)) / 135168)) + ((i2 + (64*i1)) / 4096));
                    auto tmp11 = ((tmp10 < 0) != (tmp2 < 0) ? (tmp10 % tmp2 != 0 ? tmp10 / tmp2 - 1 : tmp10 / tmp2) : tmp10 / tmp2);
                    auto tmp12 = tmp11 * tmp4;
                    auto tmp13 = tmp9 + tmp12;
                    auto tmp14 = in_ptr0[i2 + (64*((tmp13 / 13) % 12)) + (768*(i1 % 64)) + (49152*(tmp13 % 13))];
                    auto tmp15 = in_ptr2[i2 + (64*((tmp13 / 13) % 12)) + (768*(i1 % 64)) + (49152*(tmp13 % 13))];
                    out_ptr6[i2 + (64*i1) + (28672*i0)] = tmp7;
                    out_ptr7[i2 + (64*i1) + (28672*i0)] = tmp8;
                    out_ptr8[i2 + (64*i1) + (28672*i0)] = tmp14;
                    out_ptr9[i2 + (64*i1) + (28672*i0)] = tmp15;
                }
            }
        }
  2. to_dtype

At present, we don’t support vectorization of the int64 and double data types. Supporting vectorization for these types requires matching the number of vector lanes if we also want to vectorize float32 and/or int32 simultaneously. To accomplish this, we may need to use two vector variables to hold int64 or double vectors to match one float32 or int32 vector variable. The problems with the “to_dtype” function are specifically connected to these two data types. In most real benchmarks, int64 and double are used to compute scalar indices, in which case vectorization is unnecessary.
Below is an example from hrnet_w18. In this particular scenario, we don’t need to vectorize the int64 and double indices since they have no relation to i3, which is the index we want to vectorize. Hence, it suffices to leave them as scalar and not perform vectorization on them.

        #pragma omp for 
        for(long i0=0; i0<128; i0+=1)
        {
            #pragma GCC ivdep
            for(long i1=0; i1<56; i1+=1)
            {
                #pragma GCC ivdep
                for(long i2=0; i2<56; i2+=1)
                {
                    #pragma GCC ivdep
                    for(long i3=0; i3<18; i3+=1)
                    {
                        auto tmp0 = in_ptr0[i3 + (18*i2) + (1008*i1) + (56448*i0)];
                        auto tmp1 = static_cast<long>(i1);
                        auto tmp2 = static_cast<double>(tmp1);
                        auto tmp3 = static_cast<double>(1);
                        auto tmp4 = tmp2 * tmp3;
                        auto tmp5 = static_cast<double>(0);
                        auto tmp6 = tmp4 + tmp5;
                        auto tmp7 = static_cast<float>(tmp6);
                        auto tmp8 = static_cast<float>(0.5);
                        auto tmp9 = tmp7 * tmp8;
                        auto tmp10 = static_cast<long>(tmp9);
                        auto tmp11 = static_cast<long>(i2);
                        auto tmp12 = static_cast<double>(tmp11);
                        auto tmp13 = tmp12 * tmp3;
                        auto tmp14 = tmp13 + tmp5;
                        auto tmp15 = static_cast<float>(tmp14);
                        auto tmp16 = tmp15 * tmp8;
                        auto tmp17 = static_cast<long>(tmp16);
                        auto tmp18 = in_ptr1[i3 + (18*tmp17) + (504*tmp10) + (14112*i0)];
                        auto tmp19 = tmp0 + tmp18;
                        auto tmp20 = tmp19 * (tmp19>0);
                        out_ptr0[i3 + (18*i2) + (1008*i1) + (56448*i0)] = tmp20;
                    }
                }
            }
        }
  3. indirect_indexing

We currently exclude all indirect indexing cases from vectorization, but some of them can still be vectorized when the indirect index variables are invariant with respect to the loop variable we want to vectorize. One instance of this can be seen in the “to_dtype” section, where the variable “tmp18” is a load with indirect indices. However, these indices depend only on “i1” and “i2” and not on “i3”, which is the loop variable we want to vectorize. To obtain this information, we would need an analysis pass that tracks the relationship between each variable and each loop variable.
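A minimal sketch of the transformation, assuming the analysis has already proven the indirect index invariant to the inner loop: the index is resolved once per outer iteration, leaving only contiguous accesses inside. The kernel is a simplified gather-add rather than the hrnet_w18 kernel itself, and `gather_add_rows` and its parameters are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified gather-add: out[r][c] = a[r][c] + b[row_idx[r]][c].
// row_idx[r] is an indirect index, but it does not depend on the inner
// loop variable c, so it can be resolved once per row as a scalar.
inline void gather_add_rows(const float* a, const float* b,
                            const std::int64_t* row_idx,
                            float* out, long rows, long cols) {
    assert(a && b && row_idx && out);
    for (long r = 0; r < rows; ++r) {
        const float* b_row = b + cols * row_idx[r];  // hoisted indirect index
        // Only contiguous accesses remain in the inner loop, so it can be
        // processed several lanes at a time.
        for (long c = 0; c < cols; ++c) {
            out[c + cols * r] = a[c + cols * r] + b_row[c];
        }
    }
}
```

The design point is that the indirect load stays scalar; only the loop over the contiguous axis is a vectorization candidate.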

  4. unsupported masked

We vectorize the kernel containing “masked” op conservatively and don’t allow any actual computation inside it. This means that cases with nested masked bodies or computations within the “masked” element cannot be vectorized, such as the one found in jx_nest_base. However, in most cases like the example below, enabling vectorization for computation would not pose any issue.

        #pragma omp for 
        for(long i0=0; i0<32; i0+=1)
        {
            #pragma GCC ivdep
            for(long i1=0; i1<57; i1+=1)
            {
                #pragma GCC ivdep
                for(long i2=0; i2<57; i2+=1)
                {
                    #pragma GCC ivdep
                    for(long i3=0; i3<256; i3+=1)
                    {
                        auto tmp0 = static_cast<long>(i1);
                        auto tmp1 = static_cast<long>(56);
                        auto tmp2 = tmp0 < tmp1;
                        auto tmp3 = static_cast<long>(i2);
                        auto tmp4 = tmp3 < tmp1;
                        auto tmp5 = tmp2 & tmp4;
                        auto tmp6 = [&]
                        {
                            auto tmp7 = in_ptr0[i3 + (256*i2) + (14336*i1) + (802816*i0)];
                            auto tmp8 = in_out_ptr0[i2 + (56*i1) + (3136*i0)];
                            auto tmp9 = tmp7 - tmp8;
                            auto tmp10 = out_ptr1[i2 + (56*i1) + (3136*i0)];
                            auto tmp11 = static_cast<float>(256);
                            auto tmp12 = tmp10 / tmp11;
                            auto tmp13 = static_cast<float>(1e-06);
                            auto tmp14 = tmp12 + tmp13;
                            auto tmp15 = 1 / std::sqrt(tmp14);
                            auto tmp16 = tmp9 * tmp15;
                            auto tmp17 = in_ptr1[i3];
                            auto tmp18 = tmp16 * tmp17;
                            auto tmp19 = in_ptr2[i3];
                            auto tmp20 = tmp18 + tmp19;
                            return tmp20;
                        }
                        ;
                        auto tmp21 = tmp5 ? tmp6() : -std::numeric_limits<decltype(tmp6())>::infinity();
                        out_ptr2[i3 + (256*i2) + (14592*i1) + (831744*i0)] = tmp21;
                    }
                }
            }
        }
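In the kernel above, the mask “tmp5” depends only on “i1” and “i2”, not on “i3”. Under that assumption (which does not hold for every masked kernel), one possible shape for the vectorized code is a scalar test per (i1, i2) pair followed by either a contiguous body over i3 or a fill with the “other” value. The sketch below uses a reduced padding kernel with a per-channel scale and bias; `pad_scale_bias` and its parameters are hypothetical.

```cpp
#include <cassert>
#include <limits>
#include <vector>

// Pads an [h][w][c] input to [h+1][w+1][c], filling the padding with -inf,
// and applies a per-channel scale and bias inside the masked region.
inline void pad_scale_bias(const float* in, const float* scale, const float* bias,
                           float* out, long h, long w, long c) {
    assert(h > 0 && w > 0 && c > 0);
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for (long i1 = 0; i1 < h + 1; ++i1) {
        for (long i2 = 0; i2 < w + 1; ++i2) {
            // The mask depends only on i1 and i2, so it is a scalar here.
            const bool in_bounds = (i1 < h) && (i2 < w);
            float* dst = out + c * (i2 + (w + 1) * i1);
            if (in_bounds) {
                const float* src = in + c * (i2 + w * i1);
                // Masked body: plain contiguous arithmetic over i3.
                for (long i3 = 0; i3 < c; ++i3)
                    dst[i3] = src[i3] * scale[i3] + bias[i3];
            } else {
                // "other" value of the masked op; no load is performed.
                for (long i3 = 0; i3 < c; ++i3) dst[i3] = neg_inf;
            }
        }
    }
}
```

When the mask does vary across the vectorized lanes, this shortcut no longer applies, and a lane-wise blend together with masked loads would be needed instead.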
  5. unsupported dtype in load/store

As in the “to_dtype” case, the int64 and double vectorized data types are unsupported. Supporting vectorization for these types requires matching the number of vector lanes if we also want to vectorize float32 and/or int32 simultaneously. To accomplish this, we may need to use two vector variables to hold int64 or double vectors to match one float32 or int32 vector variable.
Based on real benchmarks, most of the non-vectorized cases in this category are due to the absence of int64 vectorization support. These cases fall into two main scenarios: 1) int64 is loaded for indirect indexing, and 2) int64 is loaded for computation, as illustrated in the examples from fastNLP_Bert below. In the first example, we do not need to vectorize the int64 variables “tmp0”, “tmp2” and “tmp5”, since they are invariant to “i1”, which is the loop variable being vectorized. However, int64 vectorization support is necessary for the second example.

        // We do not need to vectorize tmp0, tmp2 and tmp5 since they are invariant to i1
        #pragma omp for 
        for(long i0=0; i0<475; i0+=1)
        {
            {
                float tmp8 = 0;
                for(long i1=0; i1<768; i1+=1)
                {
                    auto tmp0 = in_ptr0[i0];
                    auto tmp5 = in_ptr3[i0];
                    auto tmp1 = in_ptr1[i1 + (768*tmp0)];
                    auto tmp2 = static_cast<long>(i0);
                    auto tmp3 = in_ptr2[i1 + (768*tmp2)];
                    auto tmp4 = tmp1 + tmp3;
                    auto tmp6 = in_ptr4[i1 + (768*tmp5)];
                    auto tmp7 = tmp4 + tmp6;
                    out_ptr0[i1 + (768*i0)] = tmp7;
                    tmp8 += tmp7;
                }
                out_ptr1[i0] = tmp8;
            }
        }
        // vectorization on tmp4 is needed since it is variant to i1
        #pragma omp for 
        for(long i0=0; i0<5700; i0+=1)
        {
            {
                float tmp10 = -std::numeric_limits<float>::infinity();
                for(long i1=0; i1<475; i1+=1)
                {
                    auto tmp0 = in_ptr0[i1 + (475*i0)];
                    auto tmp4 = in_ptr1[i1];
                    auto tmp1 = static_cast<float>(8.0);
                    auto tmp2 = tmp0 / tmp1;
                    auto tmp3 = static_cast<float>(1.0);
                    auto tmp5 = static_cast<float>(tmp4);
                    auto tmp6 = tmp3 - tmp5;
                    auto tmp7 = static_cast<float>(-10000.0);
                    auto tmp8 = tmp6 * tmp7;
                    auto tmp9 = tmp2 + tmp8;
                    tmp10 = std::max(tmp10, tmp9);
                }
                out_ptr0[i0] = tmp10;
            }
        }
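The “two vector variables” idea can be illustrated for the elementwise part of the second kernel. This is a sketch under assumptions: lane arrays stand in for vector registers, a float vector is assumed to hold 8 lanes and an int64 vector of the same width 4 lanes, and `convert_pair_to_float` and `scaled_mask_add` are hypothetical helpers, not existing codegen or ATen functions.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int kFloatLanes = 8;                // assumed lanes in one float vector
constexpr int kInt64Lanes = kFloatLanes / 2;  // same register width holds 4 int64

using FloatVec = std::array<float, kFloatLanes>;
using Int64Vec = std::array<std::int64_t, kInt64Lanes>;

// Hypothetical helper: converts two int64 vectors (low and high halves)
// into one float vector so lane counts line up with float32 operands.
inline FloatVec convert_pair_to_float(const Int64Vec& lo, const Int64Vec& hi) {
    FloatVec r{};
    for (int l = 0; l < kInt64Lanes; ++l) {
        r[l] = static_cast<float>(lo[l]);
        r[l + kInt64Lanes] = static_cast<float>(hi[l]);
    }
    return r;
}

// Computes out[i] = x[i] / 8 + (1 - float(mask[i])) * -10000,
// eight floats per step; n must be a multiple of kFloatLanes.
inline void scaled_mask_add(const float* x, const std::int64_t* mask,
                            float* out, long n) {
    assert(n % kFloatLanes == 0);
    for (long i = 0; i < n; i += kFloatLanes) {
        Int64Vec lo{}, hi{};
        for (int l = 0; l < kInt64Lanes; ++l) {  // two int64 loads per float step
            lo[l] = mask[i + l];
            hi[l] = mask[i + kInt64Lanes + l];
        }
        const FloatVec m = convert_pair_to_float(lo, hi);
        for (int l = 0; l < kFloatLanes; ++l)
            out[i + l] = x[i + l] / 8.0f + (1.0f - m[l]) * -10000.0f;
    }
}
```

The max reduction of the original kernel is left out to keep the focus on the lane-count mismatch between int64 and float32.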

“Double” is used as a scalar in all the cases we encountered, which makes vectorization unnecessary.
In addition to int64 and double, we only support vectorized bool and uint8 when they are used as masks. There are a small number of cases where uint8 is stored as bool, such as the example from DebertaForQuestionAnswering below. Vectorizing them would be straightforward since their type sizes match, but we have to be careful when types of different sizes appear in the same kernel.

        #pragma omp for 
        for(long i0=0; i0<2097152; i0+=1)
        {
            auto tmp0 = in_ptr0[i0];
            auto tmp1 = static_cast<bool>(tmp0);
            auto tmp2 = tmp1 == 0;
            out_ptr0[i0] = tmp2;
        }
  6. non-contiguous load/store (excluding indirect indexing)

CppTile2DKernel with 2d transposition support has already vectorized some of the non-contiguous load/store. However, there are still two main cases that have not been covered yet. The first case occurs frequently in most models during training backward, where the non-contiguous load/store happens on the inner-most reduction loop while being contiguous on an outer parallel loop, which is known as vertical reduction.

    #pragma GCC ivdep
    for(long i0=0; i0<1000; i0+=1)
    {
        {
            float tmp1 = 0;
            for(long i1=0; i1<128; i1+=1)
            {
                auto tmp0 = in_ptr0[i0 + (1000*i1)];
                tmp1 += tmp0;
            }
            out_ptr0[i0] = tmp1;
        }
    }
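One way to handle a vertical reduction like this is to vectorize along the outer, contiguous axis and keep a vector of accumulators, one lane per output element. The sketch below assumes 8 lanes and a column count that is a multiple of the lane width; `column_sum_lanes` is a hypothetical name, and the lane array stands in for a vector register.

```cpp
#include <array>
#include <cassert>
#include <vector>

constexpr int kLanes = 8;  // assumed vector width

// Column sums of a row-major [rows][cols] matrix:
// out[c] = sum over r of in[c + cols * r].
inline void column_sum_lanes(const float* in, float* out, long rows, long cols) {
    assert(cols % kLanes == 0);
    for (long c = 0; c < cols; c += kLanes) {
        std::array<float, kLanes> acc{};  // one accumulator per output column
        for (long r = 0; r < rows; ++r) {
            const float* src = in + c + cols * r;  // contiguous along c
            for (int l = 0; l < kLanes; ++l) acc[l] += src[l];
        }
        for (int l = 0; l < kLanes; ++l) out[c + l] = acc[l];
    }
}
```

Each column is still accumulated in row order, so the floating-point summation order per output matches the scalar loop, and no horizontal reduction across lanes is required.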

The second case involves complicated indexing formulas such as floor division (//) or ModularIndexing, and in order to achieve maximum vectorization scope, we must rely on “vectorization with fallback”.

  7. unsupported ops

We currently do not support vectorization for some operations such as bitwise_and, bitwise_or, bitwise_xor, logical_not, remainder, truediv, among others. However, most of these operations should be easy to support. Although there are a few instances of “randn” which are difficult to vectorize, they occur infrequently.

  8. unsupported reduction

Reduction operations fail to vectorize mainly because int64 vectorization is not supported, as in this example from fastNLP_Bert:

    #pragma GCC ivdep
    for(long i0=0; i0<6; i0+=1)
    {
        {
            long tmp2 = 0;
            for(long i1=0; i1<474; i1+=1)
            {
                auto tmp0 = out_ptr0[i1 + (474*i0)];
                auto tmp1 = static_cast<long>(tmp0);
                tmp2 += tmp1;
            }
            out_ptr2[i0] = tmp2;
        }
    }
  9. unsupported constant dtype

Vectorization is not implemented for constants of data type uint8 or bool. These cases occur less frequently and can be handled at low priority.

  10. unsupported store modes

“atomic_add” cannot be vectorized. Still, we can do vectorization with fallback to maximize the performance, e.g., from AlbertForMaskedLM, we are able to vectorize all the ops except for atomic_add which can be put into an inner loop.

    for(long i0=0; i0<512; i0+=1)
    {
        #pragma GCC ivdep
        for(long i1=0; i1<128; i1+=1)
        {
            auto tmp0 = in_ptr6[i0];
            auto tmp4 = out_ptr4[i1 + (128*i0)];
            auto tmp5 = out_ptr4[65536 + i1 + (128*i0)];
            auto tmp7 = out_ptr4[131072 + i1 + (128*i0)];
            auto tmp9 = out_ptr4[196608 + i1 + (128*i0)];
            auto tmp1 = static_cast<long>(-1);
            auto tmp2 = tmp0 == tmp1;
            auto tmp3 = static_cast<float>(0);
            auto tmp6 = tmp4 + tmp5;
            auto tmp8 = tmp6 + tmp7;
            auto tmp10 = tmp8 + tmp9;
            auto tmp11 = tmp2 ? tmp3 : tmp10;
            atomic_add(&out_ptr7[i1 + (128*tmp0)], tmp11);
        }
    }
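A sketch of what “vectorization with fallback” could look like for this pattern: the arithmetic is done a vector at a time into a small lane buffer, and only the scatter store runs in a scalar inner loop. Several things here are assumptions made for illustration: the kernel is reduced to a two-input sum without the “where”, indices are assumed non-negative, `scalar_add` is a single-threaded stand-in for the generated atomic_add helper, and `scatter_add_rows` is a hypothetical name.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int kLanes = 8;  // assumed vector width

// Stand-in for the generated atomic_add helper. A plain add keeps the
// sketch single-threaded and portable; it is NOT atomic.
inline void scalar_add(float* addr, float value) { *addr += value; }

// For each row i0: dst[idx[i0]][:] += a[i0][:] + b[i0][:].
// cols must be a multiple of kLanes; idx values must be valid row numbers.
inline void scatter_add_rows(const float* a, const float* b,
                             const std::int64_t* idx, float* dst,
                             long rows, long cols) {
    assert(cols % kLanes == 0);
    for (long i0 = 0; i0 < rows; ++i0) {
        float* dst_row = dst + cols * idx[i0];
        for (long i1 = 0; i1 < cols; i1 += kLanes) {
            // Vectorizable part: contiguous loads and arithmetic.
            std::array<float, kLanes> v{};
            for (int l = 0; l < kLanes; ++l)
                v[l] = a[i1 + l + cols * i0] + b[i1 + l + cols * i0];
            // Scalar fallback: the store mode that cannot be vectorized.
            for (int l = 0; l < kLanes; ++l)
                scalar_add(&dst_row[i1 + l], v[l]);
        }
    }
}
```

In a multi-threaded kernel the fallback loop would call the real atomic helper instead of `scalar_add`; the structure of the two inner loops stays the same.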

Summary

Vectorization has significantly improved the Inductor CPP backend. The analysis shows that a large portion of kernels are already vectorized. The remaining non-vectorized kernels have been grouped into 10 categories, with suggested features as the next steps. With these in place, Inductor CPU’s performance can continue to improve, making it a more efficient and effective tool for deep learning applications.
