Comparing the performance of 0.4.1 and master

A classic trope in sci-fi stories is the precursor civilization whose technology reached a pinnacle ages ago but is now long gone. With respect to framework overhead, PyTorch 0.4.1 is our precursor technology, achieving a seemingly similar feature set but with a lot less overhead. What happened? Can we restore the technology of the ancients?

I did a comparative analysis of PyTorch 0.4.1 and master on a “bad multiply” implementation (multiplication implemented by summing in a loop), using perf and callgrind, and found the following:

  • Operator setup, specifically TensorIterator, is responsible for the majority of the slowdown from PyTorch 0.4.1 to master. Yes, it is true, we have a fast path for contiguous tensors. It is not fast enough; 0.4.1 shows us exactly how much faster it can be.
  • Python and dispatcher overheads have gotten worse. No-op autograd (autograd processing on tensors with requires_grad=False) has gotten faster, likely because of the C++ Tensor-Variable merge.
  • Good performance is within reach without rewriting all of PyTorch: by rewriting TensorIterator (no small feat, but certainly tractable), you have the opportunity to lop off 30% of the runtime of this microbenchmark. I’m hoping that this analysis will help further galvanize action.

In the rest of this post, I’ll describe the experimental setup and the analysis I did to make these conclusions.

Experimental setup

The versions of PyTorch I tested were 0.4.1 (using tag v0.4.1, with some modifications) and master (ddf26816d3ca54ce7f3513f618fac93ce67d06e9, also with some modifications) on an Intel Xeon E5-2650 with gcc 9.3.0 and Anaconda Python 3.7.9. 0.4.1 was a bit crufty and needed some work to build it. Linked here are the 0.4.1 modifications and master modifications.

The most important change was adding -fno-omit-frame-pointer so that perf worked (this is a few percent instruction count regression, but I applied it to both 0.4.1 and master, so for comparative purposes it should be a wash). In 0.4.1, I had to add this to both CMAKE_CXX_FLAGS in CMakeLists.txt as well as extra_compile_args/extra_link_args in setup.py (on master, we build everything with cmake, so only CMakeLists.txt is necessary). For future performance analysis, we should make it easier to build with this setting, and I filed #51151 to track this.
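
For concreteness, here is a minimal setuptools sketch of where the flag ends up on the setup.py side; the extension name and source file are made up, and the real 0.4.1 setup.py is of course much more involved.

from setuptools import Extension

# Illustrative only: the real 0.4.1 setup.py builds many extensions. The
# point is just that the flag needs to show up in both the compile args and
# the link args.
ext = Extension(
    "example._C",             # hypothetical extension name
    sources=["example.cpp"],  # hypothetical source file
    extra_compile_args=["-O3", "-fno-omit-frame-pointer"],  # keep frame pointers so perf can unwind
    extra_link_args=["-fno-omit-frame-pointer"],
)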

For the benchmark, I chose to replicate a benchmark where I ran addition in a loop, as addition is an operator that is heavily used in PyTorch and has already been the subject of other performance analyses. (It is also a bit idiosyncratic, for reasons we will see later.) I used both callgrind (instruction counts) and perf (kernel call stack sampling) to collect data on where we were spending time inside of this benchmark. To make the data easier to analyze, I made sure to only collect data on the actual additions in the loop, and not include interpreter / shared library startup time, which often makes ordinary perf profiles harder to interpret. I did this by toggling on the relevant profiler just before entering the region to be profiled:

import torch
import subprocess
import os
import time
import signal

torch.set_num_threads(1)

def bad_multiply(x: torch.Tensor, k: int):
    output = torch.zeros_like(x)
    for _ in range(k):
        output = output + x
    return output

x = torch.randn(10**2)

if True:  # True: profile with perf; False: profile with callgrind
    p = subprocess.Popen(['/usr/bin/perf', 'record', '-g', '-o', 'perf.data', '-p', str(os.getpid())])
    # give perf time to start running
    time.sleep(2)
    bad_multiply(x, 300000)
    p.send_signal(signal.SIGINT)
    p.wait()
else:
    # _valgrind_toggle didn't exist in 0.4.1, so I patched it in  
    torch._C._valgrind_toggle()
    bad_multiply(x, 300)
    torch._C._valgrind_toggle()

Interpreting perf

After profiling a program with perf, we can find out what percent of the time we were in any given function. Because the sampling frequency is fixed, we can extrapolate from the total number of samples to roughly how long we spent in a function. (Subject to sampling error! Functions that are very short but called a lot may not be adequately represented by perf.)
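
As a rough illustration of that extrapolation (the sampling rate here is an assumption on my part, perf record's usual default of 4000 Hz; none of the analysis below depends on the absolute numbers):

# Back-of-the-envelope conversion from sample counts to on-CPU time. The rate
# is assumed, not recorded as part of this experiment; only the relative
# comparison between the two runs matters for the analysis.
ASSUMED_SAMPLE_RATE_HZ = 4000

def samples_to_seconds(samples, rate_hz=ASSUMED_SAMPLE_RATE_HZ):
    return samples / rate_hz

print(samples_to_seconds(1383))  # 0.4.1 run: roughly 0.35s sampled
print(samples_to_seconds(2976))  # master run: roughly 0.74s sampled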

Let’s take a look at the call stacks recorded by perf for both programs. In fact, the Python binding code in PyTorch has not changed all that much and in both cases THPVariable_add is the entry point from the Python interpreter. The details of the call stacks are not too important, but I’ve put them below so you can get a feel for them.

Perf call stack (0.4.1) - 1383 total samples, 955 in add

- 46.12% torch::autograd::THPVariable_add
  - 40.45% torch::autograd::VariableType::add
     - 39.34% at::Type::add
        - 37.38% at::native::add
           - 35.44% at::Type::th_add
              - 30.22% torch::autograd::VariableType::s_th_add
                 - 20.84% at::CPUFloatType::s_th_add
                    - 10.71% THFloatTensor_cadd
                       - 5.91% GOMP_parallel@@VERSION
                          + 2.76% THFloatTensor_cadd._omp_fn.0
                            1.16% THFloatVector_cadd_AVX2
                            0.53% __kmpc_bound_num_threads
                       + 2.63% THFloatTensor_resizeNd
                    + 3.58% at::CPUFloatTensor::CPUFloatTensor
                      1.79% THFloatTensor_nElement
                      0.83% omp_in_parallel@@VERSION
                      0.78% THFloatTensor_isContiguous
                      0.58% at::checked_cast_tensor<at::CPUFloatTensor, at::TensorImpl>
                      0.53% at::CPUFloatTensor::CPUFloatTensor
                 + 2.18% torch::autograd::as_variable
                   0.77% torch::jit::tracer::IsTracing::operator()
                   0.53% torch::autograd::VariableType::unpack
                1.45% at::expand_outplace
  + 3.30% torch::PythonArgParser::raw_parse
    0.73% torch::autograd::utils::wrap
    0.58% torch::PythonArgs::scalar

Perf call stack (master) - 2976 total samples, 1712 in add

- 57.45% torch::autograd::THPVariable_add
   - 49.60% at::Tensor::add
      - 47.04% c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, c
         - 46.25% torch::autograd::VariableType::(anonymous namespace)::add_Tensor
            - 42.73% at::add
               - 41.65% c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&, c10::Scalar>
                  - 40.16% c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tens
                     - 39.42% at::(anonymous namespace)::wrapper_add_Tensor
                        - 28.32% at::meta::add_Tensor::meta
                           - 27.18% at::TensorIteratorBase::build_binary_op
                              - 23.88% at::TensorIteratorBase::build
                                 - 14.92% at::TensorIteratorBase::fast_set_up
                                    - 10.91% at::(anonymous namespace)::structured_add_out_functional::set_output
                                       - 9.29% at::native::empty_cpu
                                          - 8.14% at::detail::empty_cpu
                                             + 2.77% c10::DefaultCPUAllocator::allocate
                                             + 1.82% at::detail::make_tensor<c10::TensorImpl, c10::intrusive_ptr<c10::StorageImpl, c10::detail::intrusive_target_default_null_type<c10
                                      1.25% at::TensorIteratorBase::compute_fast_setup_type
                                   1.96% at::TensorIteratorBase::compute_types
                                   1.38% at::TensorIteratorBase::populate_operands
                                   1.35% at::TensorIteratorBase::compute_shape
                                0.74% at::TensorIteratorConfig::add_input
                        - 9.38% at::native::structured_add_out::impl
                           - 8.74% at::native::(anonymous namespace)::add_kernel
                              - 8.20% at::native::(anonymous namespace)::add_kernel(at::TensorIteratorBase&, c10::Scalar)::{lambda()#2}::operator()
                                 - 6.34% at::TensorIteratorBase::for_each
                                    - 6.04% at::TensorIteratorBase::for_each
                                       - 5.57% at::TensorIteratorBase::serial_for_each
                                          - 1.65% c10::function_ref<void (char**, long const*, long, long)>::callback_fn<at::TensorIteratorBase::for_each(c10::function_ref<void (char
                                               c10::function_ref<void (char**, long const*, long)>::callback_fn<at::native::(anonymous namespace)::cpu_kernel_vec<true, at::native::(a
                                            1.05% at::TensorIteratorBase::get_data_ptrs
                                            0.98% at::TensorIteratorBase::get_strides
              0.98% at::AutoNonVariableTypeMode::AutoNonVariableTypeMode
   + 3.04% torch::PythonArgParser::raw_parse
     0.81% torch::PyWarningHandler::~PyWarningHandler
     0.71% THPVariable_Wrap 

To set up a comparison of overheads, we will have to do some work:

  1. First, we need to draw a correspondence between functions in 0.4.1 and functions in master. In some cases this is easy, because the function exists in both versions (e.g., THPVariable_add), but in most cases it is not so easy, and requires some knowledge of what each function in the call stack does. Fortunately, we don’t have to do this for every function in the stack, only for the important ones where the functionality changes.
  2. Second, we need to normalize the percentages so that we can compare them (remember, a percentage says what fraction of that invocation of perf we spent in the function). We’ll do this by multiplying the percentage by the total number of samples (at a fixed frequency!) perf made in that run; a minimal sketch of this arithmetic follows the list.
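
Here is that sketch, using numbers from the call stacks above:

# Minimal sketch of step 2: convert a perf percentage (which is relative to a
# single run) into an absolute sample count, so the two runs can be compared.
def to_samples(percent, total_samples):
    return percent / 100 * total_samples

# THPVariable_add in each run, in absolute samples (percentages and totals
# taken from the call stacks above):
print(to_samples(46.12, 1383))  # 0.4.1
print(to_samples(57.45, 2976))  # master, ~1710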

I did the analysis in this spreadsheet. I took a subset of the functions in the perf call stack and classified each into a rough “bucket” of overhead (a small illustrative sketch of the classification follows the list below). The buckets I chose were:

  • No-op autograd (VariableType; but not torch::autograd::VariableType::add!). This is the overhead attributed to going through the autograd system. Because the benchmark in question involves only requires_grad=False tensors, this is effectively a no-op, and you would hope the overhead here would be zero. We’ve gotten better in master here, because we no longer have to allocate a new Variable tensor to wrap tensors for autograd, even when nothing else has happened.
  • Computation (THFloatVector_cadd_AVX2, cpu_kernel_vec<add_kernel>). These are the functions that actually do the vectorized addition. Notice that they are a very small portion of the total time; even 0.4.1 has a lot of overhead. This makes sense: even without working particularly hard, if you drop all support for features and error checking, you can get 1000x speedups on addition.
  • Python (THPVariable_add). These are the functions that take in PyObjects from Python and parse them.
  • Dispatcher. I used this bucket to attribute any place in the code where we were just trying to figure out “where to go”. For example, torch::autograd::VariableType::add doesn’t actually do any autograd work; instead, it calls into a composite function that checks if any arguments are sparse and branches off to alternate kernels in that case. In master, we literally have a component called the dispatcher.
  • Operator setup. This is everything related to running an operator that isn’t actual computation and isn’t accounted for in any of the other categories. Error checking, dtype checking, output allocation, etc. all fall under this umbrella.
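
To make the classification a bit more concrete, here is a small illustrative subset of the mapping; these tags are my readings of the descriptions above, and the authoritative classification is the one in the spreadsheet.

# Illustrative subset only: a few frames from the call stacks above, tagged
# with the buckets described in this list. The full classification (and the
# actual arithmetic) lives in the linked spreadsheet.
BUCKETS = {
    "torch::autograd::THPVariable_add":        "python",
    "torch::autograd::VariableType::add":      "dispatcher",      # 0.4.1
    "c10::Dispatcher::call":                   "dispatcher",      # master
    "torch::autograd::VariableType::s_th_add": "no-op autograd",  # 0.4.1
    "at::CPUFloatType::s_th_add":              "operator setup",  # 0.4.1
    "at::TensorIteratorBase::build_binary_op": "operator setup",  # master
    "THFloatVector_cadd_AVX2":                 "computation",     # 0.4.1
    "cpu_kernel_vec<add_kernel>":              "computation",     # master
}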

When you compare the numbers, you can see that most of the extra time, relative to 0.4.1, is spent in operator setup.

Aside: Callgrind works too!

In the spreadsheet, you will see that I also have instruction counts taken from callgrind (using valgrind --callgrind-out-file=callgrind.out --tool=callgrind --dump-line=yes --dump-instr=yes --instr-atstart=yes --collect-atstart=no) and then analyzed the data with callgrind_annotate callgrind.out --inclusive=yes. In fact, I did the callgrind experiments first, because they were easier to set up than perf (no need to take care of -fno-omit-frame-pointer, and the experiments are 100% deterministic). There is not really much to say: as you can see, although there is a little distortion, the callgrind numbers lead to essentially the same conclusions as the perf numbers.

Why is 0.4.1 operator setup fast?

We could ask how we can speed up operator setup, given the numbers above. But that would lead us into a long discussion about TensorIterator internals that I think would obscure rather than enlighten. Instead, I want to briefly walk you through 0.4.1’s operator setup, and give some intuition for why 0.4.1 operator setup is fast: it does less.

Operator setup in 0.4.1 is split up into three parts. First, in at::Type::th_add, broadcasting is handled by allocating new expanded tensors if broadcasting is necessary (108 samples):

Tensor Type::th_add(const Tensor & self, const Tensor & other, Scalar alpha) const {
    const DeviceGuard device_guard(self);
    Tensor b_self, b_other;
    std::tie(b_self, b_other) = expand_outplace(self, other, "th_add");
    return s_th_add(b_self, b_other, alpha); // ***
}
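
As a rough Python-level analogue of what expand_outplace is doing (an illustration only, not the actual code path): it broadcasts both inputs to a common shape up front, so everything downstream only ever sees matching sizes. In this microbenchmark the sizes already match, so there is little for this step to do.

import torch

# Rough analogue of expand_outplace: produce broadcasted (expanded) views of
# both inputs before handing them to the kernel.
a = torch.randn(100, 1)
b = torch.randn(100)
b_a, b_b = torch.broadcast_tensors(a, b)
assert b_a.shape == b_b.shape == torch.Size([100, 100])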

Next, at::CPUFloatType::s_th_add is responsible for allocating the output tensor and checking the dtypes of the input tensors (210 samples):

Tensor CPUFloatType::s_th_add(const Tensor & self, const Tensor & other, Scalar alpha) const {
    const DeviceGuard device_guard(self);
    auto result_ = new CPUFloatTensor(context);
    auto result = Tensor(result_, false);
    auto self_ = checked_cast_tensor<CPUFloatTensor>(self.pImpl,"self",1, false);
    auto alpha_ = alpha.toFloat();
    auto other_ = checked_cast_tensor<CPUFloatTensor>(other.pImpl,"other",3, false);
    THFloatTensor_cadd(result_->tensor, self_->tensor, alpha_, other_->tensor);  // ***
    result_->maybeScalar(self_->isScalar() && other_->isScalar());
    return result;
}

Finally, in THFloatTensor_cadd, we check that the input and output tensors have the same size, that they are all contiguous, and that we are not doing an in-place operation. When all of these hold, we go straight to the contiguous kernel (198 samples). That’s it! (In the code sample below, I’ve erased all portions of the code that don’t get exercised in this microbenchmark.)

void THTensor_(cadd)(THTensor *r_, THTensor *t, real value, THTensor *src)
{
  THTensor_(resizeAs)(r_, t);
  int64_t r_Size = THTensor_(nElement)(r_);
  int64_t srcSize = THTensor_(nElement)(src);
  int r_Contig = THTensor_(isContiguous)(r_);
  int tContig = THTensor_(isContiguous)(t);
  int srcContig = THTensor_(isContiguous)(src);
  int serial_path = 0;
  if (srcSize == r_Size){
    if (r_Contig && tContig && srcContig) {
      if(r_ == t) { ...
      } else {
        TH_TENSOR_APPLY3_CONTIG(real, r_, real, t, real, src, THVector_(cadd)(r__data, t_data, src_data, value, r__len););
      }
    } else { ... }
  } else { ... }
  if (serial_path) { ... }
}

I could explain all of this code in five minutes.

Conclusion

TensorIterator came into existence for a very good reason: in TH, handling all of the non-contiguous cases required quite a lot of boilerplate (the code I’ve elided with ellipses is nontrivial!). TensorIterator made it possible to define a new binary operator that handles all of the edge cases PyTorch users expect to work correctly, in only a dozen lines of code. But the abstraction came at a cost: in particular, as we added new features to TensorIterator, we lost the capacity to reason simply about what is actually happening inside it. This is why I argue that a rewrite of TensorIterator is in order: we must bring this transparency back. As Zachary DeVito always says, if you want to make a system fast, first you have to make it simple.

Acknowledgements. Thank you Gregory Chanan and Basil Hosmer for pushing for this comparison to actually happen in the first place, Scott Wolchok for putting up with my newbie perf questions, Sam Gross for writing TensorIterator in the first place (in this post, I rag on it a lot, but it has really served us well for a long time without a major rewrite), and Taylor Robie for driving a lot of benchmarking work which I built on top of for this post.
