Note on torch.compile Generated Code
I’ve been receiving a lot of questions about what exactly torch.compile generates when compiling the optimizer. For context, the post on foreach kernels describes the main mechanism through which inductor generates the large horizontally fused kernels used in the optimizer. torch.compile uses this codegen to compile the traced optimizer graph into two kernels: the first performs the increment on the step values, and the second is a fully vertically and horizontally fused optimizer kernel.
For example, consider the following optimizer test code:
import torch

# three small parameters, using the foreach + capturable Adam implementation
params = [torch.rand(4, 5, device="cuda") for _ in range(3)]
kwargs = {"foreach": True, "capturable": True}
optimizer = torch.optim.Adam(params, **kwargs)

# fake gradients so that step() has something to apply
for p in params:
    p.grad = torch.ones_like(p)

opt_step = torch.compile(optimizer.step)
opt_step()
Running this with TORCH_LOGS="output_code" will show the full code that inductor generates.
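If you prefer to enable this from inside the script, recent PyTorch versions expose the same artifact log programmatically (a small sketch; enable it before triggering compilation):

import torch
torch._logging.set_logs(output_code=True)  # same artifact as TORCH_LOGS="output_code"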
For the purposes of this note, the two kernels I described above are:
The increment kernel:
@triton_heuristics.foreach(
num_warps=8,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
inductor_meta={'kernel_name': 'triton_for_fused_0', 'backend_hash': '63937d058519033f995f0585a4aab6c8c8898fe6839dd14ce1536da9b902b160', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2):
    xpid = tl.program_id(0)
    XBLOCK: tl.constexpr = 1024
    if xpid >= 0 and xpid < 1:
        xpid_offset = xpid - 0
        xnumel = 1
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        tmp0 = tl.load(in_ptr0 + (0))
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
        tmp2 = 1.0
        tmp3 = tmp1 + tmp2
        tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp3, None)
    elif xpid >= 1 and xpid < 2:
        xpid_offset = xpid - 1
        xnumel = 1
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        tmp4 = tl.load(in_ptr1 + (0))
        tmp5 = tl.broadcast_to(tmp4, [XBLOCK])
        tmp6 = 1.0
        tmp7 = tmp5 + tmp6
        tl.store(out_ptr1 + (tl.full([XBLOCK], 0, tl.int32)), tmp7, None)
    elif xpid >= 2 and xpid < 3:
        xpid_offset = xpid - 2
        xnumel = 1
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        tmp8 = tl.load(in_ptr2 + (0))
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK])
        tmp10 = 1.0
        tmp11 = tmp9 + tmp10
        tl.store(out_ptr2 + (tl.full([XBLOCK], 0, tl.int32)), tmp11, None)
    else:
        pass
''', device_str='cuda')
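Conceptually, this first kernel just bumps each parameter's scalar step counter by one. A rough eager-mode equivalent (illustrative only; this is not the generated wrapper code) would be:

# capturable Adam keeps each parameter's "step" as a tensor in the optimizer state
step_tensors = [optimizer.state[p]["step"] for p in params]
torch._foreach_add_(step_tensors, 1.0)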
And then the full optimizer kernel representing the main compute of Adam:
@triton_heuristics.foreach(
num_warps=8,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: '*fp32', 11: '*fp32', 12: '*fp32', 13: '*fp32', 14: '*fp32', 15: '*fp32', 16: '*fp32', 17: '*fp32', 18: '*fp32', 19: '*fp32', 20: '*fp32', 21: '*fp32', 22: '*fp32', 23: '*fp32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
inductor_meta={'kernel_name': 'triton_for_fused_1', 'backend_hash': '63937d058519033f995f0585a4aab6c8c8898fe6839dd14ce1536da9b902b160', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, out_ptr1, out_ptr2, out_ptr3, out_ptr5, out_ptr6, out_ptr7, out_ptr9, out_ptr10, out_ptr11):
    xpid = tl.program_id(0)
    XBLOCK: tl.constexpr = 1024
    if xpid >= 0 and xpid < 1:
        xpid_offset = xpid - 0
        xnumel = 20
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (x0), xmask)
        tmp3 = tl.load(in_ptr1 + (x0), xmask)
        tmp8 = tl.load(in_ptr2 + (x0), xmask)
        tmp13 = tl.load(in_ptr3 + (x0), xmask)
        tmp15 = tl.load(in_ptr4 + (0))
        tmp16 = tl.broadcast_to(tmp15, [XBLOCK])
        tmp1 = 0.999
        tmp2 = tmp0 * tmp1
        tmp4 = tmp3 * tmp3
        tmp5 = 0.0010000000000000009
        tmp6 = tmp4 * tmp5
        tmp7 = tmp2 + tmp6
        tmp9 = tmp3 - tmp8
        tmp10 = 0.09999999999999998
        tmp11 = tmp9 * tmp10
        tmp12 = tmp8 + tmp11
        tmp14 = libdevice.sqrt(tmp7)
        tmp17 = libdevice.pow(tmp1, tmp16)
        tmp18 = 1.0
        tmp19 = tmp17 - tmp18
        tmp20 = -tmp19
        tmp21 = libdevice.sqrt(tmp20)
        tmp22 = tmp14 / tmp21
        tmp23 = 1e-08
        tmp24 = tmp22 + tmp23
        tmp25 = 0.9
        tmp26 = libdevice.pow(tmp25, tmp16)
        tmp27 = tmp26 - tmp18
        tmp28 = 1000.0
        tmp29 = tmp27 * tmp28
        tmp30 = 1 / tmp29
        tmp31 = tmp24 / tmp30
        tmp32 = tmp12 / tmp31
        tmp33 = tmp13 + tmp32
        tl.store(out_ptr1 + (x0), tmp12, xmask)
        tl.store(out_ptr2 + (x0), tmp33, xmask)
        tl.store(out_ptr3 + (x0), tmp7, xmask)
    elif xpid >= 1 and xpid < 2:
        xpid_offset = xpid - 1
        xnumel = 20
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x1 = xindex
        tmp34 = tl.load(in_ptr5 + (x1), xmask)
        tmp37 = tl.load(in_ptr6 + (x1), xmask)
        tmp42 = tl.load(in_ptr7 + (x1), xmask)
        tmp47 = tl.load(in_ptr8 + (x1), xmask)
        tmp49 = tl.load(in_ptr9 + (0))
        tmp50 = tl.broadcast_to(tmp49, [XBLOCK])
        tmp35 = 0.999
        tmp36 = tmp34 * tmp35
        tmp38 = tmp37 * tmp37
        tmp39 = 0.0010000000000000009
        tmp40 = tmp38 * tmp39
        tmp41 = tmp36 + tmp40
        tmp43 = tmp37 - tmp42
        tmp44 = 0.09999999999999998
        tmp45 = tmp43 * tmp44
        tmp46 = tmp42 + tmp45
        tmp48 = libdevice.sqrt(tmp41)
        tmp51 = libdevice.pow(tmp35, tmp50)
        tmp52 = 1.0
        tmp53 = tmp51 - tmp52
        tmp54 = -tmp53
        tmp55 = libdevice.sqrt(tmp54)
        tmp56 = tmp48 / tmp55
        tmp57 = 1e-08
        tmp58 = tmp56 + tmp57
        tmp59 = 0.9
        tmp60 = libdevice.pow(tmp59, tmp50)
        tmp61 = tmp60 - tmp52
        tmp62 = 1000.0
        tmp63 = tmp61 * tmp62
        tmp64 = 1 / tmp63
        tmp65 = tmp58 / tmp64
        tmp66 = tmp46 / tmp65
        tmp67 = tmp47 + tmp66
        tl.store(out_ptr5 + (x1), tmp46, xmask)
        tl.store(out_ptr6 + (x1), tmp67, xmask)
        tl.store(out_ptr7 + (x1), tmp41, xmask)
    elif xpid >= 2 and xpid < 3:
        xpid_offset = xpid - 2
        xnumel = 20
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x2 = xindex
        tmp68 = tl.load(in_ptr10 + (x2), xmask)
        tmp71 = tl.load(in_ptr11 + (x2), xmask)
        tmp76 = tl.load(in_ptr12 + (x2), xmask)
        tmp81 = tl.load(in_ptr13 + (x2), xmask)
        tmp83 = tl.load(in_ptr14 + (0))
        tmp84 = tl.broadcast_to(tmp83, [XBLOCK])
        tmp69 = 0.999
        tmp70 = tmp68 * tmp69
        tmp72 = tmp71 * tmp71
        tmp73 = 0.0010000000000000009
        tmp74 = tmp72 * tmp73
        tmp75 = tmp70 + tmp74
        tmp77 = tmp71 - tmp76
        tmp78 = 0.09999999999999998
        tmp79 = tmp77 * tmp78
        tmp80 = tmp76 + tmp79
        tmp82 = libdevice.sqrt(tmp75)
        tmp85 = libdevice.pow(tmp69, tmp84)
        tmp86 = 1.0
        tmp87 = tmp85 - tmp86
        tmp88 = -tmp87
        tmp89 = libdevice.sqrt(tmp88)
        tmp90 = tmp82 / tmp89
        tmp91 = 1e-08
        tmp92 = tmp90 + tmp91
        tmp93 = 0.9
        tmp94 = libdevice.pow(tmp93, tmp84)
        tmp95 = tmp94 - tmp86
        tmp96 = 1000.0
        tmp97 = tmp95 * tmp96
        tmp98 = 1 / tmp97
        tmp99 = tmp92 / tmp98
        tmp100 = tmp80 / tmp99
        tmp101 = tmp81 + tmp100
        tl.store(out_ptr9 + (x2), tmp80, xmask)
        tl.store(out_ptr10 + (x2), tmp101, xmask)
        tl.store(out_ptr11 + (x2), tmp75, xmask)
    else:
        pass
''', device_str='cuda')
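To make the arithmetic easier to follow, here is a rough Python sketch of what each branch computes for its parameter (my own paraphrase with illustrative names, not the generated code). The hyperparameters are constant-folded into the kernel from Adam's defaults: 0.999 and 0.0010000000000000009 are beta2 and 1 - beta2, 0.9 and 0.09999999999999998 are beta1 and 1 - beta1, 1e-08 is eps, and 1000.0 is 1/lr for lr=1e-3.

def adam_branch(param, grad, exp_avg, exp_avg_sq, step):
    beta1, beta2, lr, eps = 0.9, 0.999, 1e-3, 1e-8

    # second moment update: tmp7 = exp_avg_sq * 0.999 + grad * grad * (1 - 0.999)
    exp_avg_sq = exp_avg_sq * beta2 + grad * grad * (1 - beta2)
    # first moment update in lerp form: tmp12 = exp_avg + (grad - exp_avg) * (1 - 0.9)
    exp_avg = exp_avg + (grad - exp_avg) * (1 - beta1)

    # bias corrections use the step value loaded from in_ptr4/in_ptr9/in_ptr14
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    # denominator: tmp24 = sqrt(exp_avg_sq) / sqrt(bias_correction2) + eps
    denom = exp_avg_sq.sqrt() / bias_correction2 ** 0.5 + eps
    # parameter update: tmp33 folds lr, bias_correction1 and the sign into one divide
    param = param - lr * exp_avg / (bias_correction1 * denom)

    return param, exp_avg, exp_avg_sq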
These kernels are meant to fully utilize memory bandwidth: the generated code assigns pieces of each input to different program IDs so they can be processed in parallel, keeping as many loads in flight simultaneously as possible. (That’s what the large sequence of if/elif statements is doing.)
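As a back-of-the-envelope illustration (my own arithmetic, not part of the generated output), the launch grid for the example above is tiny:

XBLOCK = 1024
numels = [20, 20, 20]                                   # numel of each 4x5 param
blocks_per_tensor = [-(-n // XBLOCK) for n in numels]   # ceil division -> [1, 1, 1]
grid = (sum(blocks_per_tensor),)                        # 3 programs, one per if/elif branch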
On top of this codegen, we also use either CUDA graphs or the cpp wrapper to make the step faster. In the full output, you can see the Python wrapper code that inductor generates to call the generated Triton kernels. When benchmarked against APEX, this Python wrapper can be slow when the foreach kernels take a large number of arguments, which is the case for many models. CUDA graphs removes this overhead by simply replaying the recorded Triton kernel launches, while the cpp wrapper reduces it by, in essence, generating a C++ version of the wrapper code.
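For reference, a rough sketch of how the two paths can be enabled (these knobs exist in current releases but may change):

import torch

# CUDA graphs: "reduce-overhead" mode records the kernel launches once and then
# replays them, removing the Python wrapper overhead on subsequent steps.
opt_step = torch.compile(optimizer.step, mode="reduce-overhead")

# C++ wrapper: ask inductor to emit the wrapper in C++ instead of Python
# (set before compiling).
torch._inductor.config.cpp_wrapper = True
opt_step = torch.compile(optimizer.step)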