Performance Comparison Between torch.compile and APEX Optimizers

Note on torch.compile Generated Code

I’ve been receiving a lot of questions about what exactly torch.compile generates when compiling the optimizer. For context, the post on foreach kernels describes the main mechanism through which inductor generates large horizontally fused kernels like those used in the optimizer. torch.compile uses this codegen to compile the traced optimizer graph into two kernels: the first performs the increment on the step values, and the second is a fully vertically and horizontally fused optimizer kernel.

For example, consider the following optimizer test code:

import torch

params = [torch.rand(4, 5, device="cuda") for _ in range(3)]
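# foreach=True requests the horizontally fused (multi-tensor) implementation;
# capturable=True keeps the step counters on the GPU as tensors so the step is CUDA-graph safe.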
kwargs = {"foreach": True, "capturable": True}
optimizer = torch.optim.Adam(params, **kwargs)

for p in params:
    p.grad = torch.ones_like(p)

opt_step = torch.compile(optimizer.step)

opt_step()

Running this with TORCH_LOGS="output_code" will show the full code that inductor generates.
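The same logging can also be switched on from Python instead of via the environment variable; a minimal sketch, continuing the snippet above (torch._logging.set_logs is the programmatic counterpart of TORCH_LOGS):

import torch

# Equivalent to TORCH_LOGS="output_code": print the inductor-generated wrapper
# and triton kernels when the compiled step first runs.
torch._logging.set_logs(output_code=True)

opt_step = torch.compile(optimizer.step)
opt_step()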

For the purposes of this note, the two kernels I described above are:

The increment kernel:

@triton_heuristics.foreach(
    num_warps=8,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'kernel_name': 'triton_for_fused_0', 'backend_hash': '63937d058519033f995f0585a4aab6c8c8898fe6839dd14ce1536da9b902b160', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2):
    xpid = tl.program_id(0)
    XBLOCK: tl.constexpr = 1024
    if xpid >= 0 and xpid < 1:
        xpid_offset = xpid - 0
        xnumel = 1
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        tmp0 = tl.load(in_ptr0 + (0))
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
        tmp2 = 1.0
        tmp3 = tmp1 + tmp2
        tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp3, None)
    elif xpid >= 1 and xpid < 2:
        xpid_offset = xpid - 1
        xnumel = 1
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        tmp4 = tl.load(in_ptr1 + (0))
        tmp5 = tl.broadcast_to(tmp4, [XBLOCK])
        tmp6 = 1.0
        tmp7 = tmp5 + tmp6
        tl.store(out_ptr1 + (tl.full([XBLOCK], 0, tl.int32)), tmp7, None)
    elif xpid >= 2 and xpid < 3:
        xpid_offset = xpid - 2
        xnumel = 1
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        tmp8 = tl.load(in_ptr2 + (0))
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK])
        tmp10 = 1.0
        tmp11 = tmp9 + tmp10
        tl.store(out_ptr2 + (tl.full([XBLOCK], 0, tl.int32)), tmp11, None)
    else:
        pass
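In eager terms, this kernel is just the per-parameter step increment fused into a single launch; an illustrative sketch of what it computes (not the actual traced code):

# Each branch of the kernel bumps one parameter's scalar step counter,
# so the whole kernel is the fused equivalent of:
for p in params:
    optimizer.state[p]["step"] += 1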

And then the full optimizer kernel representing the main compute of Adam:

@triton_heuristics.foreach(
    num_warps=8,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: '*fp32', 11: '*fp32', 12: '*fp32', 13: '*fp32', 14: '*fp32', 15: '*fp32', 16: '*fp32', 17: '*fp32', 18: '*fp32', 19: '*fp32', 20: '*fp32', 21: '*fp32', 22: '*fp32', 23: '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'kernel_name': 'triton_for_fused_1', 'backend_hash': '63937d058519033f995f0585a4aab6c8c8898fe6839dd14ce1536da9b902b160', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, out_ptr1, out_ptr2, out_ptr3, out_ptr5, out_ptr6, out_ptr7, out_ptr9, out_ptr10, out_ptr11):
    xpid = tl.program_id(0)
    XBLOCK: tl.constexpr = 1024
    if xpid >= 0 and xpid < 1:
        xpid_offset = xpid - 0
        xnumel = 20
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (x0), xmask)
        tmp3 = tl.load(in_ptr1 + (x0), xmask)
        tmp8 = tl.load(in_ptr2 + (x0), xmask)
        tmp13 = tl.load(in_ptr3 + (x0), xmask)
        tmp15 = tl.load(in_ptr4 + (0))
        tmp16 = tl.broadcast_to(tmp15, [XBLOCK])
        tmp1 = 0.999
        tmp2 = tmp0 * tmp1
        tmp4 = tmp3 * tmp3
        tmp5 = 0.0010000000000000009
        tmp6 = tmp4 * tmp5
        tmp7 = tmp2 + tmp6
        tmp9 = tmp3 - tmp8
        tmp10 = 0.09999999999999998
        tmp11 = tmp9 * tmp10
        tmp12 = tmp8 + tmp11
        tmp14 = libdevice.sqrt(tmp7)
        tmp17 = libdevice.pow(tmp1, tmp16)
        tmp18 = 1.0
        tmp19 = tmp17 - tmp18
        tmp20 = -tmp19
        tmp21 = libdevice.sqrt(tmp20)
        tmp22 = tmp14 / tmp21
        tmp23 = 1e-08
        tmp24 = tmp22 + tmp23
        tmp25 = 0.9
        tmp26 = libdevice.pow(tmp25, tmp16)
        tmp27 = tmp26 - tmp18
        tmp28 = 1000.0
        tmp29 = tmp27 * tmp28
        tmp30 = 1 / tmp29
        tmp31 = tmp24 / tmp30
        tmp32 = tmp12 / tmp31
        tmp33 = tmp13 + tmp32
        tl.store(out_ptr1 + (x0), tmp12, xmask)
        tl.store(out_ptr2 + (x0), tmp33, xmask)
        tl.store(out_ptr3 + (x0), tmp7, xmask)
    elif xpid >= 1 and xpid < 2:
        xpid_offset = xpid - 1
        xnumel = 20
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x1 = xindex
        tmp34 = tl.load(in_ptr5 + (x1), xmask)
        tmp37 = tl.load(in_ptr6 + (x1), xmask)
        tmp42 = tl.load(in_ptr7 + (x1), xmask)
        tmp47 = tl.load(in_ptr8 + (x1), xmask)
        tmp49 = tl.load(in_ptr9 + (0))
        tmp50 = tl.broadcast_to(tmp49, [XBLOCK])
        tmp35 = 0.999
        tmp36 = tmp34 * tmp35
        tmp38 = tmp37 * tmp37
        tmp39 = 0.0010000000000000009
        tmp40 = tmp38 * tmp39
        tmp41 = tmp36 + tmp40
        tmp43 = tmp37 - tmp42
        tmp44 = 0.09999999999999998
        tmp45 = tmp43 * tmp44
        tmp46 = tmp42 + tmp45
        tmp48 = libdevice.sqrt(tmp41)
        tmp51 = libdevice.pow(tmp35, tmp50)
        tmp52 = 1.0
        tmp53 = tmp51 - tmp52
        tmp54 = -tmp53
        tmp55 = libdevice.sqrt(tmp54)
        tmp56 = tmp48 / tmp55
        tmp57 = 1e-08
        tmp58 = tmp56 + tmp57
        tmp59 = 0.9
        tmp60 = libdevice.pow(tmp59, tmp50)
        tmp61 = tmp60 - tmp52
        tmp62 = 1000.0
        tmp63 = tmp61 * tmp62
        tmp64 = 1 / tmp63
        tmp65 = tmp58 / tmp64
        tmp66 = tmp46 / tmp65
        tmp67 = tmp47 + tmp66
        tl.store(out_ptr5 + (x1), tmp46, xmask)
        tl.store(out_ptr6 + (x1), tmp67, xmask)
        tl.store(out_ptr7 + (x1), tmp41, xmask)
    elif xpid >= 2 and xpid < 3:
        xpid_offset = xpid - 2
        xnumel = 20
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x2 = xindex
        tmp68 = tl.load(in_ptr10 + (x2), xmask)
        tmp71 = tl.load(in_ptr11 + (x2), xmask)
        tmp76 = tl.load(in_ptr12 + (x2), xmask)
        tmp81 = tl.load(in_ptr13 + (x2), xmask)
        tmp83 = tl.load(in_ptr14 + (0))
        tmp84 = tl.broadcast_to(tmp83, [XBLOCK])
        tmp69 = 0.999
        tmp70 = tmp68 * tmp69
        tmp72 = tmp71 * tmp71
        tmp73 = 0.0010000000000000009
        tmp74 = tmp72 * tmp73
        tmp75 = tmp70 + tmp74
        tmp77 = tmp71 - tmp76
        tmp78 = 0.09999999999999998
        tmp79 = tmp77 * tmp78
        tmp80 = tmp76 + tmp79
        tmp82 = libdevice.sqrt(tmp75)
        tmp85 = libdevice.pow(tmp69, tmp84)
        tmp86 = 1.0
        tmp87 = tmp85 - tmp86
        tmp88 = -tmp87
        tmp89 = libdevice.sqrt(tmp88)
        tmp90 = tmp82 / tmp89
        tmp91 = 1e-08
        tmp92 = tmp90 + tmp91
        tmp93 = 0.9
        tmp94 = libdevice.pow(tmp93, tmp84)
        tmp95 = tmp94 - tmp86
        tmp96 = 1000.0
        tmp97 = tmp95 * tmp96
        tmp98 = 1 / tmp97
        tmp99 = tmp92 / tmp98
        tmp100 = tmp80 / tmp99
        tmp101 = tmp81 + tmp100
        tl.store(out_ptr9 + (x2), tmp80, xmask)
        tl.store(out_ptr10 + (x2), tmp101, xmask)
        tl.store(out_ptr11 + (x2), tmp75, xmask)
    else:
        pass
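The arithmetic inside each branch is the standard Adam update with the default hyperparameters (lr=1e-3, betas=(0.9, 0.999), eps=1e-8), which is where constants like 0.999, 0.9, 1e-08, and 1000.0 (= 1/lr) come from. As an illustrative eager-mode sketch of what each branch computes per parameter (not the actual traced code):

# Rough eager equivalent of one branch of the fused Adam kernel above.
for p in params:
    state = optimizer.state[p]
    exp_avg, exp_avg_sq, step = state["exp_avg"], state["exp_avg_sq"], state["step"]
    grad = p.grad
    exp_avg_sq.mul_(0.999).addcmul_(grad, grad, value=1 - 0.999)  # second moment
    exp_avg.lerp_(grad, 1 - 0.9)                                  # first moment
    bias_correction1 = 1 - 0.9 ** step    # step is a tensor because capturable=True
    bias_correction2 = 1 - 0.999 ** step
    denom = exp_avg_sq.sqrt() / bias_correction2.sqrt() + 1e-8
    p.add_(-(1e-3 / bias_correction1) * exp_avg / denom)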

These kernels are meant to fully utilize memory bandwidth: they assign pieces of each input to different program IDs to be processed in parallel, loading as many operands as possible simultaneously. (That’s what the large sequence of if statements is doing: each branch handles one parameter’s slice of the work.)

On top of this codegen, we use either cudagraphs or the cpp wrapper to make this faster. In the full output, you can see the wrapper code that inductor generates to call the generated triton kernels. When benchmarked against APEX, this Python wrapper code can be slow when the foreach kernels take lots of arguments, which is the case for many models. Cudagraphs removes this overhead by simply replaying the generated triton kernels; the cpp wrapper reduces it by, in essence, generating a C++ version of the wrapper code.
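Both are enabled through the usual knobs; a rough sketch, assuming a recent PyTorch build (mode="reduce-overhead" turns on cudagraphs, and torch._inductor.config.cpp_wrapper switches to the C++ wrapper):

import torch
import torch._inductor.config as inductor_config

# Option 1: cudagraphs, which replays the recorded kernel launches.
opt_step = torch.compile(optimizer.step, mode="reduce-overhead")

# Option 2: generate a C++ wrapper instead of the default Python wrapper code.
inductor_config.cpp_wrapper = True
opt_step = torch.compile(optimizer.step)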
