```py
@torch.compile(dynamic=True)
def fn(x):
    """Return ``(values, indices)`` of the max of ``x`` along its last dim.

    Minimal repro: with ``dynamic=True`` the reduction length is symbolic,
    so Inductor emits a looped (non-persistent) reduction kernel.
    """
    return torch.max(x, -1)
```
generates the following code:
```py
@triton.jit
def triton_red_fused_max_0(in_ptr0, out_ptr0, out_ptr1, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    # Looped reduction for `torch.max(x, -1)`: each program covers XBLOCK
    # output elements (xindex) and walks the r0_numel reduction elements
    # R0_BLOCK at a time. Max values are stored to out_ptr0 and argmax
    # indices to out_ptr1.
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    # Running max accumulator feeding the values output (out_ptr0).
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    # Running (max, index) accumulator feeding the indices output (out_ptr1);
    # the index starts at the int64 maximum (2**63 - 1).
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    _tmp4_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        # ks0 is presumably the symbolic (dynamic=True) row stride of in_ptr0 — TODO confirm.
        tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), xmask & r0_mask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = triton_helpers.maximum(_tmp2, tmp1)
        # Masked-out lanes keep their previous accumulator, so the other=0.0
        # fill value never enters the reduction.
        _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
        _tmp4_next, _tmp4_index_next = triton_helpers.maximum_with_index(
            _tmp4, _tmp4_index, tmp1, rindex
        )
        _tmp4 = tl.where(r0_mask & xmask, _tmp4_next, _tmp4)
        _tmp4_index = tl.where(r0_mask & xmask, _tmp4_index_next, _tmp4_index)
    # Collapse the R0_BLOCK lanes along axis 1 into one value per output element.
    tmp2 = triton_helpers.max2(_tmp2, 1)[:, None]
    # NOTE(review): tmp4_val is computed but unused; it should equal tmp2,
    # which would make the separate _tmp2 reduction redundant — confirm NaN
    # handling of maximum vs. maximum_with_index matches before relying on it.
    tmp4_val, tmp4_idx = triton_helpers.max_with_index(_tmp4, _tmp4_index, 1)
    tmp4 = tmp4_idx[:, None]
    tl.store(out_ptr0 + (x0), tmp2, xmask)
    tl.store(out_ptr1 + (x0), tmp4, xmask)
```
This could be improved by doing:
```diff
diff --git a/out.py b/out.py
index 5d0acd594f7..5c3879867ed 100644
--- a/out.py
+++ b/out.py
@@ -8,7 +8,6 @@ def triton_red_fused_max_0(in_ptr0, out_ptr0, out_ptr1, ks0, xnumel, r0_numel, X
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
- _tmp2 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
_tmp4 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
_tmp4_index = tl.full([XBLOCK, R0_BLOCK], 9223372036854775807, tl.int64)
for r0_offset in range(0, r0_numel, R0_BLOCK):
@@ -19,15 +18,13 @@ def triton_red_fused_max_0(in_ptr0, out_ptr0, out_ptr1, ks0, xnumel, r0_numel, X
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_1 + ks0*x0), xmask & r0_mask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
- tmp3 = triton_helpers.maximum(_tmp2, tmp1)
- _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
_tmp4_next, _tmp4_index_next = triton_helpers.maximum_with_index(
_tmp4, _tmp4_index, tmp1, rindex
)
_tmp4 = tl.where(r0_mask & xmask, _tmp4_next, _tmp4)
_tmp4_index = tl.where(r0_mask & xmask, _tmp4_index_next, _tmp4_index)
- tmp2 = triton_helpers.max2(_tmp2, 1)[:, None]
tmp4_val, tmp4_idx = triton_helpers.max_with_index(_tmp4, _tmp4_index, 1)
tmp4 = tmp4_idx[:, None]
+ tmp2 = tmp4_val[:, None]
tl.store(out_ptr0 + (x0), tmp2, xmask)
tl.store(out_ptr1 + (x0), tmp4, xmask)
```
because the `argmax` reduction already computes the `amax` (as `tmp4_val`), so we don't need a separate reduction.
We could either:
1) Have a single two-output reduction op that computes both `amax` and `argmax`.
2) Combine the two at codegen time using the reduction cache. (We could use a `DeferredLine` to swap between `triton_helpers.max_with_index` and `triton_helpers.max2` depending on whether the output is used.)
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @desertfire @aakhundov