Perf counters for fun and profit

I did a fun little perf deep-dive last night that I thought some of y’all might be interested in hearing about.

The context was two-fold: first, I was intrigued by some recent benchmarks showing performance improvements from jit-compiling ReLU, and second, I wanted to play with a skylake-avx512 server. So I thought, why not tackle both?

First I replicated the setup using the microbenchmark harness. Unfortunately I noticed that the relu comparison wasn't totally fair to the ATen kernel: it was doing out.copy_(a.relu_()) to avoid allocating an output tensor, but that approach takes 3 trips to memory (a read-modify-write for the in-place relu, then a read and a write for the copy) instead of the 2 (one read, one write) that an out-of-place relu into a preallocated buffer needs.
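To make that concrete, here's a minimal sketch of the two variants (illustrative only: the real harness is microbenchmarks.py, and torch.clamp with out= is just a stand-in for writing relu's output into a preallocated buffer, not necessarily what the patched benchmark uses):

import timeit

import torch

a = torch.randn(1024, 1024)
out = torch.empty_like(a)

# Unfair to ATen: relu_ does a read-modify-write of a, and copy_ adds another
# read and write -- 3 trips to memory.
three_trips = lambda: out.copy_(a.relu_())

# Fairer: one read of a, one write of out -- 2 trips to memory.
two_trips = lambda: torch.clamp(a, min=0, out=out)

print("3 trips:", timeit.timeit(three_trips, number=1000))
print("2 trips:", timeit.timeit(two_trips, number=1000))

When I patched the benchmark to do a fair comparison along these lines, the results were much closer: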

relu 1024 1024
NNC: 0.164 s
ATen: 0.161 s
Speedup (ATen/NNC): 0.97 

Here we see ATen is very slightly faster (3%), which is disappointing but probably noise, right? Maybe if I re-run it several times, I’ll get some results where NNC is faster, and then I’ll feel reassured and go to bed happy.

Narrator: that didn’t happen.

That tiny slowdown persisted across dozens of runs, and while it’s probably not something that matters in any global sense, it bugs me. This is a simple kernel! Why is it slower?!?!

Right then. Let’s look at some assembly. Here’s NNC’s generated code, dumped via:

PYTORCH_JIT_LOG_LEVEL=">>llvm_codegen" python microbenchmarks.py

[screenshot: pages of NNC's generated AVX-512 assembly]

This goes on for pages. It’s unrolled the inner loop by a factor of 1024 and vectorized it using AVX-512 instructions. Verbose, but seems pretty good, right? Let’s check out ATen. I found the easiest way to dive into the assembly is to capture a perf profile of the interesting operator:

perf record -g -- python microbenchmarks.py
perf report

Then, you can select the top frame. Doing so is more annoying in open-source builds than in fbcode, because we build without -fno-omit-frame-pointer and the call stacks are broken; luckily I happen to know that the function_ref call into threshold_kernel is the interesting one.

Press “a” to show annotated assembly and you can see the hot part of the op implementation clearly (press “o” to get the offsets so you can see where branches go):

[screenshot: perf annotate view of the threshold_kernel hot loop]

If anything this looks worse! It’s doing a vcmp + vblend to implement relu, instead of just vmax! And it’s using 256-bit wide AVX2 instructions instead of the new AVX-512 hotness. What gives?

At this point I had a hypothesis: the unrolled code is huge, maybe we're blowing out the instruction cache? I copied the NNC assembly dump into a .S file and compiled it with clang++ -c -o foo.o foo.S -march=native so I could look at the size of the generated code with objdump. But a bit of math showed the hot loop is only ~1000 bytes: big by human standards, but much smaller than the 32 KB L1 instruction cache.
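For the curious, that bit of math can also be done mechanically. Here's a rough sketch that pulls the instruction offsets out of objdump's disassembly and takes their span; it assumes foo.o contains just the one kernel from the dump:

import re
import subprocess

# Disassemble the object file built from the NNC assembly dump.
dump = subprocess.check_output(["objdump", "-d", "foo.o"], text=True)

# objdump prints each instruction as "  <hex offset>:\t<bytes>\t<mnemonic>",
# so collect every offset and measure first-to-last.
offsets = [int(m.group(1), 16) for m in re.finditer(r"^\s*([0-9a-f]+):", dump, re.MULTILINE)]
print(f"code size ~= {max(offsets) - min(offsets)} bytes")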

But wait! There's an even sneakier uOP cache (a.k.a. the decoded stream buffer, or DSB) that could be filling up. But no: it holds something like 1.5k uOPs (https://en.wikichip.org/wiki/intel/microarchitectures/skylake_(client)), and our ~1000-byte loop decodes to a few hundred uOPs at most.

I had one last theory. Maybe AVX-512 is actually a curse, not a blessing. I’ve heard rumors that AVX-512 can cause clock throttling. How could I figure out if that’s happening…?

Enter perf list. This command shows you alllll the hardware counters you could ever want, and then several hundred more. Randomly searching the output for "512" (perf list | grep -B1 512 does the trick) I found:

core_power.lvl2_turbo_license
[Core cycles where the core was running in a manner where Turbo may be clipped to the AVX512 turbo schedule]

What are the odds that there's a counter for exactly what I want?! (The "license" here is the core's power license: heavy AVX-512 work bumps a Skylake core to license level 2, where its maximum turbo frequency is reduced.) Let's see the stats for the ATen kernel with only AVX2:

% perf stat -e cycles,core_power.lvl2_turbo_license python microbenchmarks.py
8330778710      cycles
         0      core_power.lvl2_turbo_license

No throttling there! And now the NNC kernel:

% perf stat -e cycles,core_power.lvl2_turbo_license python microbenchmarks.py
8386784008      cycles
 583958024      core_power.lvl2_turbo_license

Boom. We’re spending 7% of cycles in a clock-throttled state. That seems like strong evidence that AVX-512-related downclocking is the source of our perf woes. Let’s check this by disabling AVX-512 with a small diff to NNC’s LLVM configuration:

diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.cpp b/torch/csrc/jit/tensorexpr/llvm_jit.cpp
index d9b726b902..94d527f6ff 100644
--- a/torch/csrc/jit/tensorexpr/llvm_jit.cpp
+++ b/torch/csrc/jit/tensorexpr/llvm_jit.cpp
@@ -50,11 +50,13 @@ static llvm::orc::JITTargetMachineBuilder makeTargetMachineBuilder() {
   llvm::StringMap<bool> FeatureMap;
   llvm::sys::getHostCPUFeatures(FeatureMap);
   for (auto& Feature : FeatureMap) {
-    SubtargetFeatures.AddFeature(Feature.first(), Feature.second);
+    if (!Feature.first().startswith("avx512")) {
+      SubtargetFeatures.AddFeature(Feature.first(), Feature.second);
+    }
   }
   JTMB.setCodeGenOptLevel(llvm::CodeGenOpt::Default);
-  JTMB.setCPU(llvm::sys::getHostCPUName().str());
+  JTMB.setCPU("skylake");  // Host CPU is skylake-avx512
   JTMB.addFeatures(SubtargetFeatures.getFeatures());
   JTMB.getOptions().AllowFPOpFusion = llvm::FPOpFusion::Fast;

With AVX-512 disabled, the generated kernel uses only AVX2, and the new code induces no throttled clocks:

8014232049      cycles
         0      core_power.lvl2_turbo_license

And the overall performance is:

relu 1024 1024
NNC: 0.160 s
ATen: 0.162 s
Speedup (ATen/NNC):  1.02

Two whole percent faster! Woo! You can take that to the bank!

Hope y’all enjoyed this deep-dive. It was fun to explore, and I got to use an arcane hardware perf counter, which always makes for a good day.
