Performance Comparison between Torch.Compile and APEX optimizers

TL;DR

  • Compiled Adam outperformed the SOTA hand-optimized APEX optimizers on all benchmarks, with speedups of 62.99% on TorchBench, 53.18% on HuggingFace, 142.75% on TIMM, and 88.13% on BlueBerries
  • Compiled AdamW performed similarly, with up to a 2x improvement on TIMM
  • Compiled SGD had at least a 30% speedup on all benchmarks.
  • Some models are particularly sensitive to overheads; for these models, enabling the cpp_wrapper or cudagraphs can improve performance by 2x or more.
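For overhead-sensitive models, the two mitigations mentioned above can be enabled roughly as follows. This is a sketch, not an exact recipe: `torch._inductor.config.cpp_wrapper` and the `"reduce-overhead"` compile mode (which turns on cudagraphs) exist in recent PyTorch releases, but the exact knobs may vary across versions.

```python
import torch
import torch._inductor.config as inductor_config

# cpp_wrapper: inductor emits its kernel-launcher code in C++ instead of Python,
# reducing per-kernel launch overhead for argument-heavy foreach kernels.
inductor_config.cpp_wrapper = True

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), foreach=True, capturable=True)

# mode="reduce-overhead" enables cudagraphs, which replays the recorded
# kernel launches instead of re-executing the Python wrapper each step.
compiled_step = torch.compile(optimizer.step, mode="reduce-overhead")
```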

Background

Currently, torch.compile supports compiling 11 of the 13 optimizers into optimized foreach kernels. To drive adoption, it is necessary to show that compiled optimizers can reach SOTA performance, outperforming APEX, NVIDIA's fused optimizer kernel library and the main performant alternative to the PyTorch optimizers. APEX supports three optimizers in common with torch.compile: SGD, AdamW, and Adam. These optimizers are evaluated on models from the HuggingFace, TorchBench, and TIMM benchmark suites.

Speedup Results (Raw Data) (script)

Analysis

The different test suites had varying performance characteristics due to model size. Because the APEX and torch.compile kernels are so fast, TorchBench models, along with smaller TIMM models, were particularly sensitive to overheads. For example, the profile for the compiled Adam optimizer on dcgan is shown below.


Figure 1: dcgan torch.compile profile

This profile is almost entirely dominated by overhead; the triton kernel launches are the extremely small blocks in the lower right-hand corner. Further investigation showed the overhead came from three components: 1) guards, 2) cudagraphs checks, and 3) cudagraph launches.

Guards are conditions that torch.compile must check on each invocation to verify that the compiled artifact is still valid. Three approaches were used to lower this overhead. To improve guard performance in general, Animesh Jain moved guard evaluation to C++. On top of this, since optimizer state tensors do not change across iterations, tensor shape guards were switched to guards on the tensors' data pointers, which are much faster to evaluate. Finally, redundant guards arising from the repetitive structure of the optimizer were removed entirely.
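The intuition behind the data-pointer guard can be sketched in plain Python. This is an illustrative model only, not the real C++ guard machinery: `FakeTensor` and both guard functions are stand-ins invented for this example.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor's guard-relevant attributes."""
    def __init__(self, ptr, shape):
        self._ptr = ptr
        self.shape = shape

    def data_ptr(self):
        return self._ptr

def shape_guard(tensors, expected_shapes):
    # Must rebuild and compare a shape tuple per tensor on every step.
    return all(t.shape == s for t, s in zip(tensors, expected_shapes))

def data_ptr_guard(tensors, expected_ptrs):
    # One integer comparison per tensor: if the storage hasn't moved,
    # the shapes recorded at compile time are still valid, because the
    # optimizer never resizes or replaces its state tensors.
    return all(t.data_ptr() == p for t, p in zip(tensors, expected_ptrs))

state = [FakeTensor(ptr=1000 + i, shape=(4, 5)) for i in range(3)]
shapes = [t.shape for t in state]
ptrs = [t.data_ptr() for t in state]

assert shape_guard(state, shapes)
assert data_ptr_guard(state, ptrs)
```

The point is that the data-pointer guard is both cheaper per check and only sound because optimizer state tensors are static across iterations.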

For 2), cudagraphs checks that the data pointers of parameter tensors are static across invocations in order to guarantee the correctness of the recorded cudagraph. These checks were initially implemented in python; moving them to C++ improved performance significantly.

This overhead initially resulted in compiled Adam being 20% slower than APEX on torchbench even with cudagraphs, but after the above changes, compiled Adam outperformed APEX by up to 2x on TIMM. Overall, HuggingFace appeared to give the most realistic results: those kernels operate on tensors large enough that kernel runtime begins to dominate total runtime, which is the more typical scenario for mainstream models. To illustrate this, the profile for XGLMForCausalLM from HuggingFace is shown below.


Figure 2: XGLMForCausalLM torch.compile profile

Note on SGD

SGD is an interesting test case for overhead comparison with Eager. Because SGD is already a single memory-bound kernel in Eager, there are no vertical fusion opportunities, which is illustrated by the lack of speedup (~0.96x) on Torchbench and HuggingFace with generic torch.compile.
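Concretely, the (momentum-free) SGD update is a single elementwise expression per parameter, so there are no intermediate results for the compiler to fuse away. A pure-Python illustration of the math, with hypothetical list-of-lists inputs standing in for parameter tensors:

```python
def sgd_step(params, grads, lr=0.1):
    # The entire update is one memory-bound elementwise op per parameter:
    # p <- p - lr * g. With no intermediates to eliminate, compilation can
    # only help via horizontal fusion across parameters and lower launch
    # overhead (e.g. via cudagraphs), not via vertical fusion.
    return [[x - lr * g for x, g in zip(p, gs)]
            for p, gs in zip(params, grads)]

params = [[1.0, 2.0], [3.0, 4.0]]
grads = [[0.5, 0.5], [1.0, 1.0]]
print(sgd_step(params, grads))  # [[0.95, 1.95], [2.9, 3.9]]
```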

Why is Torch.Compile faster?

Since the APEX optimizers are highly optimized, it is important to address why torch.compile is faster. There could be a multitude of reasons: python overhead (which torch.compile traces away), better code generation from triton, or other optimizations. Below are the profiles generated for XGLMForCausalLM with torch.compile and APEX, respectively.


Figure 3: Torch.Compile XGLMForCausalLM Profile


Figure 4: APEX XGLMForCausalLM Profile

From these profiles, the answer is that it is a combination of overhead reduction (~2ms) and the kernels themselves being faster by about 4ms due to better performance of the triton code.

TIMM Performance

TIMM performance was surprisingly slow without the cpp wrapper or cudagraphs. For example, on ghostnet_100, adding the cpp wrapper improves performance by 2x. The profile is shown below.


Figure 5: ghostnet_100 torch.compile profile with launch overhead

Looking at this profile, there are obvious gaps between kernel launches. These gaps come from the inductor-generated python code that launches the generated triton kernels: extracting argument pointers and performing the launch itself adds per-kernel overhead. This problem is unique to foreach kernels because, in order to maximize the available memory bandwidth, they take a much larger number of arguments than a typical kernel. In addition, the argument tensors themselves are very small, resulting in kernels that run so fast they finish before the next kernel can be launched, decreasing GPU utilization.

After enabling the cpp wrapper, inductor generates the launcher code in C++ instead of python, which lowers this overhead significantly, resulting in the following.


Figure 6: ghostnet_100 torch.compile profile w/ cpp_wrapper

This drastically reduces the launch overhead per kernel, resulting in a > 2x speedup.

Overall with these optimizations the compiled optimizers outperform the APEX optimizers on HuggingFace, TIMM, and TorchBench.

Next Steps

  • Performance can be improved further with autotuning tile sizes
  • LRScheduler compatibility with torch.compile’d optimizers

Acknowledgements

Thanks to Animesh Jain for guard optimizations, Elias Ellison for help with optimizing cudagraphs, Jason Ansel for design feedback, and Jane Xu for reviewing this post.


Note On torch.compile Generated Code

I’ve been receiving a lot of questions about what exactly torch.compile generates when compiling the optimizer. For context, the post on foreach kernels describes the main mechanism through which inductor generates large horizontally fused kernels like those used in the optimizer. torch.compile uses this codegen to compile the traced optimizer graph into two kernels: the first performs the increment on the step values, and the second is a fully vertically and horizontally fused optimizer kernel.

For example, consider the following optimizer test code:

import torch

# foreach=True selects the horizontally fused implementation;
# capturable=True keeps the step count on-device so the whole step can be captured
params = [torch.rand(4, 5, device="cuda") for _ in range(3)]
kwargs = {"foreach": True, "capturable": True}
optimizer = torch.optim.Adam(params, **kwargs)

# attach dummy gradients so step() has something to apply
for p in params:
    p.grad = torch.ones_like(p)

opt_step = torch.compile(optimizer.step)

opt_step()

Running this with TORCH_LOGS="output_code" will show the full code that inductor generates.

For the purposes of this note, the two kernels I described above are:

The increment kernel:

@triton_heuristics.foreach(
    num_warps=8,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'kernel_name': 'triton_for_fused_0', 'backend_hash': '63937d058519033f995f0585a4aab6c8c8898fe6839dd14ce1536da9b902b160', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2):
    xpid = tl.program_id(0)
    XBLOCK: tl.constexpr = 1024
    if xpid >= 0 and xpid < 1:
        xpid_offset = xpid - 0
        xnumel = 1
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        tmp0 = tl.load(in_ptr0 + (0))
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
        tmp2 = 1.0
        tmp3 = tmp1 + tmp2
        tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp3, None)
    elif xpid >= 1 and xpid < 2:
        xpid_offset = xpid - 1
        xnumel = 1
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        tmp4 = tl.load(in_ptr1 + (0))
        tmp5 = tl.broadcast_to(tmp4, [XBLOCK])
        tmp6 = 1.0
        tmp7 = tmp5 + tmp6
        tl.store(out_ptr1 + (tl.full([XBLOCK], 0, tl.int32)), tmp7, None)
    elif xpid >= 2 and xpid < 3:
        xpid_offset = xpid - 2
        xnumel = 1
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        tmp8 = tl.load(in_ptr2 + (0))
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK])
        tmp10 = 1.0
        tmp11 = tmp9 + tmp10
        tl.store(out_ptr2 + (tl.full([XBLOCK], 0, tl.int32)), tmp11, None)
    else:
        pass

And then the full optimizer kernel representing the main compute of Adam:

@triton_heuristics.foreach(
    num_warps=8,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: '*fp32', 11: '*fp32', 12: '*fp32', 13: '*fp32', 14: '*fp32', 15: '*fp32', 16: '*fp32', 17: '*fp32', 18: '*fp32', 19: '*fp32', 20: '*fp32', 21: '*fp32', 22: '*fp32', 23: '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'kernel_name': 'triton_for_fused_1', 'backend_hash': '63937d058519033f995f0585a4aab6c8c8898fe6839dd14ce1536da9b902b160', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, out_ptr1, out_ptr2, out_ptr3, out_ptr5, out_ptr6, out_ptr7, out_ptr9, out_ptr10, out_ptr11):
    xpid = tl.program_id(0)
    XBLOCK: tl.constexpr = 1024
    if xpid >= 0 and xpid < 1:
        xpid_offset = xpid - 0
        xnumel = 20
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (x0), xmask)
        tmp3 = tl.load(in_ptr1 + (x0), xmask)
        tmp8 = tl.load(in_ptr2 + (x0), xmask)
        tmp13 = tl.load(in_ptr3 + (x0), xmask)
        tmp15 = tl.load(in_ptr4 + (0))
        tmp16 = tl.broadcast_to(tmp15, [XBLOCK])
        tmp1 = 0.999
        tmp2 = tmp0 * tmp1
        tmp4 = tmp3 * tmp3
        tmp5 = 0.0010000000000000009
        tmp6 = tmp4 * tmp5
        tmp7 = tmp2 + tmp6
        tmp9 = tmp3 - tmp8
        tmp10 = 0.09999999999999998
        tmp11 = tmp9 * tmp10
        tmp12 = tmp8 + tmp11
        tmp14 = libdevice.sqrt(tmp7)
        tmp17 = libdevice.pow(tmp1, tmp16)
        tmp18 = 1.0
        tmp19 = tmp17 - tmp18
        tmp20 = -tmp19
        tmp21 = libdevice.sqrt(tmp20)
        tmp22 = tmp14 / tmp21
        tmp23 = 1e-08
        tmp24 = tmp22 + tmp23
        tmp25 = 0.9
        tmp26 = libdevice.pow(tmp25, tmp16)
        tmp27 = tmp26 - tmp18
        tmp28 = 1000.0
        tmp29 = tmp27 * tmp28
        tmp30 = 1 / tmp29
        tmp31 = tmp24 / tmp30
        tmp32 = tmp12 / tmp31
        tmp33 = tmp13 + tmp32
        tl.store(out_ptr1 + (x0), tmp12, xmask)
        tl.store(out_ptr2 + (x0), tmp33, xmask)
        tl.store(out_ptr3 + (x0), tmp7, xmask)
    elif xpid >= 1 and xpid < 2:
        xpid_offset = xpid - 1
        xnumel = 20
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x1 = xindex
        tmp34 = tl.load(in_ptr5 + (x1), xmask)
        tmp37 = tl.load(in_ptr6 + (x1), xmask)
        tmp42 = tl.load(in_ptr7 + (x1), xmask)
        tmp47 = tl.load(in_ptr8 + (x1), xmask)
        tmp49 = tl.load(in_ptr9 + (0))
        tmp50 = tl.broadcast_to(tmp49, [XBLOCK])
        tmp35 = 0.999
        tmp36 = tmp34 * tmp35
        tmp38 = tmp37 * tmp37
        tmp39 = 0.0010000000000000009
        tmp40 = tmp38 * tmp39
        tmp41 = tmp36 + tmp40
        tmp43 = tmp37 - tmp42
        tmp44 = 0.09999999999999998
        tmp45 = tmp43 * tmp44
        tmp46 = tmp42 + tmp45
        tmp48 = libdevice.sqrt(tmp41)
        tmp51 = libdevice.pow(tmp35, tmp50)
        tmp52 = 1.0
        tmp53 = tmp51 - tmp52
        tmp54 = -tmp53
        tmp55 = libdevice.sqrt(tmp54)
        tmp56 = tmp48 / tmp55
        tmp57 = 1e-08
        tmp58 = tmp56 + tmp57
        tmp59 = 0.9
        tmp60 = libdevice.pow(tmp59, tmp50)
        tmp61 = tmp60 - tmp52
        tmp62 = 1000.0
        tmp63 = tmp61 * tmp62
        tmp64 = 1 / tmp63
        tmp65 = tmp58 / tmp64
        tmp66 = tmp46 / tmp65
        tmp67 = tmp47 + tmp66
        tl.store(out_ptr5 + (x1), tmp46, xmask)
        tl.store(out_ptr6 + (x1), tmp67, xmask)
        tl.store(out_ptr7 + (x1), tmp41, xmask)
    elif xpid >= 2 and xpid < 3:
        xpid_offset = xpid - 2
        xnumel = 20
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x2 = xindex
        tmp68 = tl.load(in_ptr10 + (x2), xmask)
        tmp71 = tl.load(in_ptr11 + (x2), xmask)
        tmp76 = tl.load(in_ptr12 + (x2), xmask)
        tmp81 = tl.load(in_ptr13 + (x2), xmask)
        tmp83 = tl.load(in_ptr14 + (0))
        tmp84 = tl.broadcast_to(tmp83, [XBLOCK])
        tmp69 = 0.999
        tmp70 = tmp68 * tmp69
        tmp72 = tmp71 * tmp71
        tmp73 = 0.0010000000000000009
        tmp74 = tmp72 * tmp73
        tmp75 = tmp70 + tmp74
        tmp77 = tmp71 - tmp76
        tmp78 = 0.09999999999999998
        tmp79 = tmp77 * tmp78
        tmp80 = tmp76 + tmp79
        tmp82 = libdevice.sqrt(tmp75)
        tmp85 = libdevice.pow(tmp69, tmp84)
        tmp86 = 1.0
        tmp87 = tmp85 - tmp86
        tmp88 = -tmp87
        tmp89 = libdevice.sqrt(tmp88)
        tmp90 = tmp82 / tmp89
        tmp91 = 1e-08
        tmp92 = tmp90 + tmp91
        tmp93 = 0.9
        tmp94 = libdevice.pow(tmp93, tmp84)
        tmp95 = tmp94 - tmp86
        tmp96 = 1000.0
        tmp97 = tmp95 * tmp96
        tmp98 = 1 / tmp97
        tmp99 = tmp92 / tmp98
        tmp100 = tmp80 / tmp99
        tmp101 = tmp81 + tmp100
        tl.store(out_ptr9 + (x2), tmp80, xmask)
        tl.store(out_ptr10 + (x2), tmp101, xmask)
        tl.store(out_ptr11 + (x2), tmp75, xmask)
    else:
        pass

These kernels are meant to fully utilize memory bandwidth, so they assign pieces of each input to different program IDs to be processed in parallel, using all memory ports to load as many operands as possible simultaneously. (That’s what the large sequence of if statements is doing.)
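The partitioning can be mimicked in plain Python: a single "launch" covers one flat grid of program IDs, and contiguous ID ranges are routed to each tensor in the list, just like the chain of `if xpid >= a and xpid < b` branches above. This is an illustrative model of the generated code's structure, not the codegen itself; `foreach_add` is a name invented for this sketch.

```python
import math

def foreach_add(tensor_lists, scalar, xblock=1024):
    """Model of a horizontally fused foreach kernel: one grid, with
    contiguous ranges of program IDs assigned to each input tensor."""
    tensors = [list(t) for t in tensor_lists]
    # Each tensor gets ceil(numel / XBLOCK) program IDs, laid out back to back.
    blocks_per_tensor = [math.ceil(len(t) / xblock) for t in tensors]
    grid = sum(blocks_per_tensor)
    for xpid in range(grid):  # one "launch" covers every tensor
        # Route this program ID to its tensor (the chain of ifs in the kernel).
        for t, nblocks in zip(tensors, blocks_per_tensor):
            if xpid < nblocks:
                start = xpid * xblock
                for i in range(start, min(start + xblock, len(t))):
                    t[i] += scalar
                break
            xpid -= nblocks
    return tensors

# Three 20-element tensors (as in the example above) fit in one block each,
# so the whole fused op is a single 3-program "launch".
out = foreach_add([[0.0] * 20, [1.0] * 20, [2.0] * 20], 1.0)
assert [t[0] for t in out] == [1.0, 2.0, 3.0]
```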

In addition to this codegen, we utilize either cudagraphs or the cpp wrapper to reduce launch overhead. In the full output, you can see the wrapper code that inductor generates to call the generated triton kernels. When benchmarked against APEX, this python code can be slow when the foreach kernels take many arguments, which is the case for many models. Cudagraphs removes this overhead by simply replaying the recorded triton kernel launches; the cpp wrapper reduces it by, in essence, generating a C++ version of this wrapper code.
