TorchInductor: a PyTorch-native Compiler with Define-by-Run IR and Symbolic Shapes

Hi @jansel, I wonder why Inductor chooses Triton to generate CUDA kernels instead of other solutions like TVM/XLA?

@void-main I believe this question was answered earlier in this same thread.


Ah, my bad, I missed the earlier discussion. Thanks for pointing that out, @Lezcano!

So, if I understand correctly, the key reason for not choosing TVM is that its Tensor IR requires more expert knowledge than Triton to get good performance?

It seems the key reason for choosing Triton is that it is focused on NVIDIA GPU optimizations, whereas the others (TVM/XLA) are not tied to GPUs.

After digging into PyTorch's matmul Triton template, I think it is fairly generalized and not bound to the GPU. A hardware vendor can still port Triton and do their own transforms with this "tiled" language.
However, PyTorch's Inductor implementation is indeed rather GPU-bound, which makes it harder to separate the logic of Inductor's original role from its calls into CUDA APIs.


Does the fact that Dynamo is only for Linux and Mac mean that contributing to Inductor is not possible on Windows?

@Idank96 I’d expect dynamo/inductor to work with Windows Subsystem for Linux (WSL). Though I don’t know anyone who has tried that.

One easy way to contribute on Windows would be to try that, fix any bugs you hit, and write up instructions that other Windows users could follow.

If you want to try to add Windows support without WSL, we also welcome pull requests adding native Windows support.

It works on WSL, yes. I’ve been using the pytorch nightly build and it works fine.

@jansel What is unclear to me is how PT2.0 uses CUDA Graphs in torch.compile — the engineering details mainly.

I understand the flow:
PyTorch → TorchDynamo → TorchInductor → Triton → NVIDIA GPU

With the TORCH_COMPILE_DEBUG=1 env variable exported, I get the following few messages at the end of the logs:

...
...
[2024-03-25 11:07:02,184] torch._inductor.cudagraph_trees: [INFO] recording cudagraph tree for None
[2024-03-25 11:07:02,299] torch._inductor.cudagraph_trees: [DEBUG] Running warmup of function 0
[2024-03-25 11:07:02,462] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 0b0f632ae495eb55ca41fe06a33f3ed1
{'cassette': 0.9992951154708862, 'tape player': 0.0005148712079972029, 'cassette player': 0.0001740186125971377, 'radio': 1.0334849321225192e-05, 'CD player': 4.437319603312062e-06}
[2024-03-25 11:07:02,463] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* _splitext             /usr/lib/python3.10/genericpath.py 121
[2024-03-25 11:07:02,462] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: 0b0f632ae495eb55ca41fe06a33f3ed1
[2024-03-25 11:07:02,477] torch._inductor.cudagraph_trees: [DEBUG] Recording function 0 of graph recording id 0
[2024-03-25 11:07:02,585] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 0b0f632ae495eb55ca41fe06a33f3ed1
{'cassette': 0.9992951154708862, 'tape player': 0.0005148712079972029, 'cassette player': 0.0001740186125971377, 'radio': 1.0334849321225192e-05, 'CD player': 4.437319603312062e-06}
[2024-03-25 11:07:02,586] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: 0b0f632ae495eb55ca41fe06a33f3ed1
[2024-03-25 11:07:02,604] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 0b0f632ae495eb55ca41fe06a33f3ed1
{'cassette': 0.9992951154708862, 'tape player': 0.0005148712079972029, 'cassette player': 0.0001740186125971377, 'radio': 1.0334849321225192e-05, 'CD player': 4.437319603312062e-06}
...
...

I called my inference three times. From the documentation on CUDAGraph Trees, it seems the first call warms up, the second call records the CUDA Graph, and the third call replays the captured CUDA Graph.

My question is, in the above flow, where does CUDAGraph fit in? And who drives it?

PyTorch → TorchDynamo → TorchInductor → Triton → CUDAGraph capture (who drives it, and the logic behind what to capture and what not to) → NVIDIA GPU

I suppose the flow looks like the above, but again, who drives the CUDAGraph Trees? Is there any documentation regarding this? (Documentation helps before looking into the code, say torch/_inductor/cudagraph_trees.py, lest I misinterpret anything.)

cc @eellison who should be able to add more details.

At compile time it is implemented as:

  • Compiler passes that analyze the model to see if cudagraphs will work (there are many things cudagraphs doesn't support that must be checked for, mainly any work being done on the CPU)
  • If it can be used, codegen a runtime wrapper around the model that applies cudagraphs (see the sketch below)
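
To make the second point concrete, here is a minimal, hypothetical sketch of such a runtime wrapper written against the public torch.cuda.CUDAGraph API. This is not Inductor's actual cudagraphify (which lives in torch/_inductor/cudagraph_trees.py and handles far more cases); it just shows the warmup / record / replay pattern:

import torch

def cudagraph_wrap(fn, sample_args):
    # Hypothetical helper (not Inductor's real cudagraphify): capture fn once
    # into a CUDA graph with static buffers, then replay it on every call.
    static_args = [a.clone() for a in sample_args]

    # Warm up on a side stream so lazy initialization is not captured.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(*static_args)
    torch.cuda.current_stream().wait_stream(s)

    # Record the CUDA graph; the outputs become static buffers too.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_outputs = fn(*static_args)

    def replay(*args):
        # Copy new inputs into the static buffers and replay the recorded work.
        for dst, src in zip(static_args, args):
            dst.copy_(src)
        graph.replay()
        return static_outputs

    return replay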

I was going through the inductor code base and found the following:

The above code inside compile_fx_inner uses the cudagraphify method, which ultimately (via CUDAGraph Trees) tries to make a CUDA Graph out of compiled_graph.current_callable.

As per the signature of cudagraphify, the first argument is:

So, compiled_graph.current_callable is of type torch.fx.GraphModule.

Now my question is, what is this compiled_graph.current_callable: torch.fx.GraphModule?

In the compile_fx_inner function, I find compiled_graph defined as above.

Now gm is again a torch.fx.GraphModule.

So, questions here:

  1. What is the torch.fx.GraphModule instance that compile_fx_inner takes as input? Is it an FX GraphModule passed down from TorchDynamo?
  2. What is the torch.fx.GraphModule instance that is passed to cudagraphify in compile_fx_inner? Is it an optimized GraphModule with Triton kernels built into it?
  3. Next, is the following understanding of mine correct?
    We parse PyTorch code to produce many FX graphs (many because we might have graph breaks). These FX graphs are GraphModule instances, and TorchInductor generates Triton code from each FX graph. So, in TorchInductor, we have a set of GraphModules (compiled to use Triton), and then for each of these GraphModules we decide whether to use CUDA Graphs. (A small sketch for counting graphs and graph breaks follows at the end of this post.)
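
For point 3, a small sketch to sanity-check the "many FX graphs" part (assuming the torch._dynamo.explain API and its graph_count / graph_break_count fields; the print() is just an easy way to force a graph break):

import torch

def fn(a, b):
    x = a * b
    print("this forces a graph break")  # unsupported inside the traced graph
    return x + 1

# explain() reports how many FX graphs Dynamo produced and why it broke.
explanation = torch._dynamo.explain(fn)(
    torch.randn(10, device="cuda"), torch.randn(10, device="cuda")
)
print(explanation.graph_count, explanation.graph_break_count)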

Somewhat related, @ezyang recently published a podcast on CUDAGraph Trees:

I believe the

model: torch.fx.GraphModule

type annotation is wrong; it should be something like:

model: Callable[[List[Tensor]], List[Tensor]]

It is the wrapper code generated by TorchInductor that launches all the Triton kernels.

Is there any debug environment variable which can be set to have a glance at the wrapper code?

If we take the small example:

import os
os.environ["TORCH_COMPILE_DEBUG"] = "1"
import torch
@torch.compile(mode="reduce-overhead")
def toy_example(a, b):
    return a * b

tensor1 = torch.randn(10, device="cuda")
tensor2 = torch.randn(10, device="cuda")

toy_example(tensor1, tensor2)
toy_example(tensor1, tensor2)
toy_example(tensor1, tensor2)

I get the following debug log:

[2024-03-25 22:31:29,930] torch._dynamo.eval_frame: [DEBUG] Saving dynamo config and hash for new compiled object(s). Hash: 9fab733732f45cfca18c83a96ae70468
[2024-03-25 22:31:29,941] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: 9fab733732f45cfca18c83a96ae70468
[2024-03-25 22:31:29,944] [4/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing toy_example /tmp/ipykernel_594021/431827302.py:2
[2024-03-25 22:31:29,944] [4/0] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2024-03-25 22:31:29,945] [4/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /tmp/ipykernel_594021/431827302.py:2 in toy_example
[2024-03-25 22:31:29,945] [4/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]     @torch.compile(mode="reduce-overhead")
[2024-03-25 22:31:29,946] [4/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /tmp/ipykernel_594021/431827302.py:4 in toy_example
[2024-03-25 22:31:29,946] [4/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         return a * b
[2024-03-25 22:31:29,946] [4/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST a []
[2024-03-25 22:31:29,946] [4/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST b [LazyVariableTracker()]
[2024-03-25 22:31:29,947] [4/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE BINARY_MULTIPLY None [LazyVariableTracker(), LazyVariableTracker()]
[2024-03-25 22:31:29,947] [4/0] torch._dynamo.output_graph: [DEBUG] create_graph_input L_a_ L['a']
[2024-03-25 22:31:29,948] [4/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['a'] (10,) [<DimDynamic.STATIC: 2>] [None]
[2024-03-25 22:31:29,949] [4/0] torch._dynamo.output_graph: [DEBUG] create_graph_input L_b_ L['b']
[2024-03-25 22:31:29,950] [4/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['b'] (10,) [<DimDynamic.STATIC: 2>] [None]
[2024-03-25 22:31:29,952] [4/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RETURN_VALUE None [TensorVariable()]
[2024-03-25 22:31:29,952] [4/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing toy_example (RETURN_VALUE)
[2024-03-25 22:31:29,953] [4/0] torch._dynamo.symbolic_convert: [DEBUG] RETURN_VALUE triggered compile
[2024-03-25 22:31:29,953] [4/0] torch._dynamo.output_graph: [DEBUG] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /tmp/ipykernel_594021/431827302.py, line 4 in toy_example>], graph_break=False)
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  ===== __compiled_fn_7 =====
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  <eval_with_key>.67 class GraphModule(torch.nn.Module):
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]     def forward(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         l_a_ = L_a_
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         l_b_ = L_b_
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: /tmp/ipykernel_594021/431827302.py:4, code: return a * b
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         mul = l_a_ * l_b_;  l_a_ = l_b_ = None
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         return (mul,)
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG] 
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG]  __compiled_fn_7 <eval_with_key>.67 opcode         name    target                   args          kwargs
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] -------------  ------  -----------------------  ------------  --------
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder    l_a_    L_a_                     ()            {}
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder    l_b_    L_b_                     ()            {}
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] call_function  mul     <built-in function mul>  (l_a_, l_b_)  {}
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] output         output  output                   ((mul,),)     {}
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] 
[2024-03-25 22:31:29,956] [4/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] TRACED GRAPH TENSOR SIZES
[2024-03-25 22:31:29,956] [4/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] ===== __compiled_fn_7 =====
[2024-03-25 22:31:29,956] [4/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] l_a_: (10,)
[2024-03-25 22:31:29,956] [4/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] l_b_: (10,)
[2024-03-25 22:31:29,956] [4/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] mul: (10,)
[2024-03-25 22:31:29,956] [4/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] 
[2024-03-25 22:31:29,957] [4/0] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO] TRACED GRAPH
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO]  ===== Forward graph 7 =====
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO]  <eval_with_key>.71 from /home/abhishek/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:511 in wrapped class <lambda>(torch.nn.Module):
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO]     def forward(self, arg0_1: "f32[10]", arg1_1: "f32[10]"):
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO]         # File: /tmp/ipykernel_594021/431827302.py:4, code: return a * b
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO]         mul: "f32[10]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO]         return (mul,)
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO]         
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO] 
[2024-03-25 22:31:29,968] [4/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 4
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO] TRACED GRAPH
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO]  ===== AFTER POST GRAD =====
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO]  <eval_with_key>.72 from /home/abhishek/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:511 in wrapped class <lambda>(torch.nn.Module):
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO]     def forward(self, arg0_1: "f32[10]", arg1_1: "f32[10]"):
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO]         # File: /tmp/ipykernel_594021/431827302.py:4, code: return a * b
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO]         mul: "f32[10]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO]         return (mul,)
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO]         
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO] 
[2024-03-25 22:31:29,974] [4/0] torch._inductor.graph: [DEBUG] lowering %arg0_1 : [num_users=1] = placeholder[target=arg0_1] 
[2024-03-25 22:31:29,974] [4/0] torch._inductor.graph: [DEBUG] lowering %arg1_1 : [num_users=1] = placeholder[target=arg1_1] 
[2024-03-25 22:31:29,975] [4/0] torch._inductor.graph: [DEBUG] lowering %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, %arg1_1), kwargs = {}) 
[2024-03-25 22:31:29,976] [4/0] torch._inductor.graph: [DEBUG]   via <function mul at 0x7f952b69ba30>
[2024-03-25 22:31:29,976] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] eval True == True [statically known]
[2024-03-25 22:31:29,977] [4/0] torch._inductor.graph: [DEBUG] lowering return (mul,) 
[2024-03-25 22:31:29,978] [4/0] torch._inductor.graph: [DEBUG] Force channels last inputs for 0 conv for the current graph with id 4
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] scheduling ComputedBuffer(name='buf0', layout=FixedLayout('cuda', torch.float32, size=[10], stride=[1]), data=Pointwise(
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG]   'cuda',
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG]   torch.float32,
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG]   def inner_fn(index):
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG]       i0 = index
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG]       tmp0 = ops.load(arg0_1, i0)
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG]       tmp1 = ops.load(arg1_1, i0)
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG]       tmp2 = tmp0 * tmp1
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG]       return tmp2
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG]   ,
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG]   ranges=[10],
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG]   origin_node=mul,
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG]   origins={mul}
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] ))
[2024-03-25 22:31:29,982] [4/0] torch._inductor.scheduler: [DEBUG] scheduling output buf0
[2024-03-25 22:31:29,983] [4/0] torch._inductor.debug: [INFO] Writing debug ir to  /media/abhishek/Abhishek_NVMe/shweta_machine/trace_analysis/start-fc-gpu/fc-http-gpu-inference-torchhub-cv-mobilenet-v2/src/code/torch-compiled-reduce-overhead-dot-graphs/torch_compile_debug/run_2024_03_25_22_14_47_042588-pid_594021/torchinductor/model__7_inference_13.4/ir_pre_fusion.txt
[2024-03-25 22:31:29,985] [4/0] torch._inductor.debug: [INFO] Writing debug ir to  /media/abhishek/Abhishek_NVMe/shweta_machine/trace_analysis/start-fc-gpu/fc-http-gpu-inference-torchhub-cv-mobilenet-v2/src/code/torch-compiled-reduce-overhead-dot-graphs/torch_compile_debug/run_2024_03_25_22_14_47_042588-pid_594021/torchinductor/model__7_inference_13.4/ir_post_fusion.txt
[2024-03-25 22:31:30,006] [4/0] torch._inductor.scheduler: [DEBUG] Generating code for node buf0 with estimated runtime 0.267857
[2024-03-25 22:31:30,010] [4/0] torch._inductor.codegen.triton: [DEBUG] Generating kernel code with kernel_name: triton_poi_fused_mul_0
[2024-03-25 22:31:30,645] [4/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2024-03-25 22:31:30,646] [4/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 16, num_warps: 1, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False
[2024-03-25 22:31:30,978] [4/0] torch._inductor.graph: [DEBUG] Output code written to: /tmp/torchinductor_abhishek/6i/c6imvaznfo4qm5nsa3hyxxmktztn4risj3dr3los4gefjjf62p63.py
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] Output code: 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from ctypes import c_void_p, c_long
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import torch
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import math
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import random
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import os
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import tempfile
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from math import inf, nan
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.hooks import run_intermediate_hooks
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.utils import maybe_profile
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.codegen.memory_planning import _align as align
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch import device, empty, empty_strided
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.codecache import AsyncCompile
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.select_algorithm import extern_kernels
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] aten = torch.ops.aten
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] inductor_ops = torch.ops.inductor
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] alloc_from_pool = torch.ops.inductor._alloc_from_pool
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] async_compile = AsyncCompile()
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] # kernel path: /tmp/torchinductor_abhishek/xz/cxzhsm3ysi6dh7mibspq3kqhgzqiqdq5qebnkt52iphjqmhccl66.py
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] # Source Nodes: [mul], Original ATen: [aten.mul]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] # mul => mul
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] triton_poi_fused_mul_0 = async_compile.triton('triton_', '''
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import triton
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import triton.language as tl
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.ir import ReductionHint
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.ir import TileHint
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.triton_heuristics import AutotuneHint, pointwise
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.utils import instance_descriptor
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor import triton_helpers
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] @pointwise(
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     size_hints=[16], 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     filename=__file__,
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_0', 'mutated_arg_names': []},
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     min_elem_per_thread=0
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] )
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] @triton.jit
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     xnumel = 10
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     xoffset = tl.program_id(0) * XBLOCK
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     xmask = xindex < xnumel
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     x0 = xindex
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     tmp1 = tl.load(in_ptr1 + (x0), xmask)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     tmp2 = tmp0 * tmp1
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     tl.store(out_ptr0 + (x0), tmp2, xmask)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] ''')
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import triton
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import triton.language as tl
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.triton_heuristics import grid, start_graph, end_graph
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] async_compile.wait(globals())
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] del async_compile
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] def call(args):
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     arg0_1, arg1_1 = args
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     args.clear()
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     assert_size_stride(arg0_1, (10, ), (1, ))
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     assert_size_stride(arg1_1, (10, ), (1, ))
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     with torch.cuda._DeviceGuard(0):
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]         torch.cuda.set_device(0) # no-op to ensure context
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]         buf0 = empty((10, ), device='cuda', dtype=torch.float32)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]         # Source Nodes: [mul], Original ATen: [aten.mul]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]         stream0 = get_cuda_stream(0)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]         triton_poi_fused_mul_0.run(arg0_1, arg1_1, buf0, 10, grid=grid(10), stream=stream0)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]         del arg0_1
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]         del arg1_1
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]         return (buf0, )
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] def benchmark_compiled_module(times=10, repeat=10):
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     from torch._dynamo.testing import rand_strided
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     from torch._inductor.utils import print_performance
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     arg0_1 = rand_strided((10, ), (1, ), device='cuda:0', dtype=torch.float32)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     arg1_1 = rand_strided((10, ), (1, ), device='cuda:0', dtype=torch.float32)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     fn = lambda: call([arg0_1, arg1_1])
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     return print_performance(fn, times=times, repeat=repeat)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] if __name__ == "__main__":
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     from torch._inductor.wrapper_benchmark import compiled_module_main
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]     compiled_module_main('None', benchmark_compiled_module)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] 
[2024-03-25 22:31:30,980] [4/0] torch._inductor.graph.__output_code: [INFO] Output code written to: /tmp/torchinductor_abhishek/6i/c6imvaznfo4qm5nsa3hyxxmktztn4risj3dr3los4gefjjf62p63.py
[2024-03-25 22:31:30,981] [4/0] torch._inductor.compile_fx: [DEBUG] FX codegen and compilation took 1.013s
[2024-03-25 22:31:30,984] [4/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 4
[2024-03-25 22:31:30,984] [4/0] torch._inductor.debug: [WARNING] model__7_inference_13 debug trace: /tmp/torchinductor_abhishek/6i/c6imvaznfo4qm5nsa3hyxxmktztn4risj3dr3los4gefjjf62p63.debug
[2024-03-25 22:31:30,987] [4/0] torch._dynamo.output_graph: [INFO] Step 2: done compiler function inductor
[2024-03-25 22:31:30,988] [4/0] torch.fx.experimental.symbolic_shapes: [INFO] produce_guards
[2024-03-25 22:31:30,989] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['a'].size()[0] 10 None
[2024-03-25 22:31:30,989] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['a'].stride()[0] 1 None
[2024-03-25 22:31:30,989] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['a'].storage_offset() 0 None
[2024-03-25 22:31:30,990] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['b'].size()[0] 10 None
[2024-03-25 22:31:30,990] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['b'].stride()[0] 1 None
[2024-03-25 22:31:30,990] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['b'].storage_offset() 0 None
[2024-03-25 22:31:30,991] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['a'].size()[0] == 10
[2024-03-25 22:31:30,991] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['a'].stride()[0] == 1
[2024-03-25 22:31:30,991] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['a'].storage_offset() == 0
[2024-03-25 22:31:30,992] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['b'].size()[0] == 10
[2024-03-25 22:31:30,992] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['b'].stride()[0] == 1
[2024-03-25 22:31:30,993] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['b'].storage_offset() == 0
[2024-03-25 22:31:30,993] [4/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS:
[2024-03-25 22:31:30,995] [4/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['a'], '_dynamo_dynamic_indices') == False           # return a * b  # mp/ipykernel_594021/431827302.py:4 in toy_example
[2024-03-25 22:31:30,995] [4/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['b'], '_dynamo_dynamic_indices') == False           # return a * b  # mp/ipykernel_594021/431827302.py:4 in toy_example
[2024-03-25 22:31:30,996] [4/0] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:379 in init_ambient_guards
[2024-03-25 22:31:30,997] [4/0] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(140282244472576))  # _dynamo/output_graph.py:385 in init_ambient_guards
[2024-03-25 22:31:30,998] [4/0] torch._dynamo.guards.__guards: [DEBUG] ___compile_config_hash() == '9fab733732f45cfca18c83a96ae70468'  # _dynamo/output_graph.py:387 in init_ambient_guards
[2024-03-25 22:31:30,998] [4/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['a'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[10], stride=[1])  # return a * b  # mp/ipykernel_594021/431827302.py:4 in toy_example
[2024-03-25 22:31:30,999] [4/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['b'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[10], stride=[1])  # return a * b  # mp/ipykernel_594021/431827302.py:4 in toy_example
[2024-03-25 22:31:31,001] torch._inductor.cudagraph_trees: [INFO] recording cudagraph tree for None
[2024-03-25 22:31:31,082] torch._inductor.cudagraph_trees: [DEBUG] Running warmup of function 0
[2024-03-25 22:31:31,083] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 9fab733732f45cfca18c83a96ae70468
[2024-03-25 22:31:31,084] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: 9fab733732f45cfca18c83a96ae70468
[2024-03-25 22:31:31,084] torch._inductor.cudagraph_trees: [DEBUG] Recording function 0 of graph recording id 0
[2024-03-25 22:31:31,157] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 9fab733732f45cfca18c83a96ae70468
[2024-03-25 22:31:31,157] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: 9fab733732f45cfca18c83a96ae70468
[2024-03-25 22:31:31,158] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 9fab733732f45cfca18c83a96ae70468

Is that wrapper code shown here? If so, which one is it?

Yes, the wrapper code is this function:

There is also TORCH_LOGS=+output_code
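
For example, mirroring the TORCH_COMPILE_DEBUG pattern above, setting it before importing torch should print just the generated output code (a sketch; the exact log formatting may differ across versions):

import os
os.environ["TORCH_LOGS"] = "+output_code"  # set before importing torch
import torch

@torch.compile
def toy_example(a, b):
    return a * b

toy_example(torch.randn(10, device="cuda"), torch.randn(10, device="cuda"))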


(The end of the original post got cropped, so I'm posting the rest as a reply to the original post.)
I guess it is the code that is dumped to the file mentioned: Output code written to: /tmp/torchinductor_abhishek/6i/c6imvaznfo4qm5nsa3hyxxmktztn4risj3dr3los4gefjjf62p63.py

And it happens to be:


from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_abhishek/xz/cxzhsm3ysi6dh7mibspq3kqhgzqiqdq5qebnkt52iphjqmhccl66.py
# Source Nodes: [mul], Original ATen: [aten.mul]
# mul => mul
triton_poi_fused_mul_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers

@pointwise(
    size_hints=[16], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_0', 'mutated_arg_names': []},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 10
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0), tmp2, xmask)
''')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (10, ), (1, ))
    assert_size_stride(arg1_1, (10, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty((10, ), device='cuda', dtype=torch.float32)
        # Source Nodes: [mul], Original ATen: [aten.mul]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_mul_0.run(arg0_1, arg1_1, buf0, 10, grid=grid(10), stream=stream0)
        del arg0_1
        del arg1_1
        return (buf0, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((10, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg1_1 = rand_strided((10, ), (1, ), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([arg0_1, arg1_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

And the function of interest is call in the above module, right (the compiled wrapper code generated by TorchInductor)? Based on:

Yes, cudagraphify is wrapping call
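
As an aside, if you want to poke at that generated call directly (e.g. for profiling), one hypothetical way is to load the dumped module by path and invoke it with fresh CUDA inputs; the path below is the one from the log above and will differ on other machines:

import importlib.util
import torch

# Path taken from the "Output code written to:" line above; adjust for your run.
path = "/tmp/torchinductor_abhishek/6i/c6imvaznfo4qm5nsa3hyxxmktztn4risj3dr3los4gefjjf62p63.py"
spec = importlib.util.spec_from_file_location("output_code", path)
output_code = importlib.util.module_from_spec(spec)
spec.loader.exec_module(output_code)  # compiles the Triton kernel on import

a = torch.randn(10, device="cuda")
b = torch.randn(10, device="cuda")
(result,) = output_code.call([a, b])  # the Inductor-generated wrapper
print(result)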

@jansel I was having trouble figuring something out. Suppose we have an output_code.py file generated automatically by setting the TORCH_COMPILE_DEBUG flag to 1. This output_code.py defines, say, x Triton kernels, and the call function calls each of these x kernels, but an nsys profile of the file (run as nsys profile -o python3 output_code.py) reports execution of only, say, y kernels, where y < x.

For example with the following output_code.py file (torchinductor/model__4_inference_10.1/):


from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_abhishek/jr/cjrdra55i6oq3npg2ofcxhhxqozdch6g3eqzefwk7iyia2vnemp7.py
# Source Nodes: [tensor], Original ATen: [aten.lift_fresh]
# tensor => full_default
triton_poi_fused_lift_fresh_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[1], 
    filename=__file__,
    triton_meta={'signature': {0: '*i1', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {1: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=(1,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_lift_fresh_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '41a3d5b7867eec3e9749fd2d5d072c8778e695c90c8a6fd62f1546cb8261bd00'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.full([1], False, tl.int1)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp0, None)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream


# kernel path: /tmp/torchinductor_abhishek/sa/csa63dnba3cwdups6w77k5ehbwe2srxwqsgnkaxjxvexrisqpofw.py
# Source Nodes: [arange], Original ATen: [aten.arange]
# arange => iota
triton_poi_fused_arange_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[16], 
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_arange_1', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '41a3d5b7867eec3e9749fd2d5d072c8778e695c90c8a6fd62f1546cb8261bd00'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 10
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_abhishek/yq/cyqzktrq7yalywzrlmyzqcyugfkskjk6ateieldildbva3hwulja.py
# Source Nodes: [tensor_1], Original ATen: [aten.lift_fresh]
# tensor_1 => full_default_1
triton_poi_fused_lift_fresh_2 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[1], 
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {1: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=(1,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_lift_fresh_2', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '41a3d5b7867eec3e9749fd2d5d072c8778e695c90c8a6fd62f1546cb8261bd00'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.full([1], 0, tl.int64)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_abhishek/vk/cvk4asato2zcv7gno273d5iju2gu3zsd35tiy3m3xl2yamm5zrfe.py
# Source Nodes: [tensor_2], Original ATen: [aten.lift_fresh]
# tensor_2 => full_default_2
triton_poi_fused_lift_fresh_3 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[1], 
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {1: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=(1,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_lift_fresh_3', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '41a3d5b7867eec3e9749fd2d5d072c8778e695c90c8a6fd62f1546cb8261bd00'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.full([1], 1, tl.int64)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_abhishek/uy/cuybdehcpe364i7q5qhx5o27nuicncddeea6lcknqr56otci23gt.py
# Source Nodes: [tensor_3], Original ATen: [aten.lift_fresh]
# tensor_3 => full_default_3
triton_poi_fused_lift_fresh_4 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[1], 
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {1: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=(1,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_lift_fresh_4', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '41a3d5b7867eec3e9749fd2d5d072c8778e695c90c8a6fd62f1546cb8261bd00'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.full([1], 5, tl.int64)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_abhishek/qr/cqrvjsoxqqbmo5xjphzpeur2avys3e4j7komqlq5ri3kjdjeqtyl.py
# Source Nodes: [tensor_4], Original ATen: [aten.lift_fresh]
# tensor_4 => full_default_4
triton_poi_fused_lift_fresh_5 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[1], 
    filename=__file__,
    triton_meta={'signature': {0: '*i1', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {1: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=(1,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_lift_fresh_5', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '41a3d5b7867eec3e9749fd2d5d072c8778e695c90c8a6fd62f1546cb8261bd00'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.full([1], True, tl.int1)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp0, None)
''', device_str='cuda')


async_compile.wait(globals())
del async_compile

def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1, ), (1, ), torch.bool)
        # Source Nodes: [tensor], Original ATen: [aten.lift_fresh]
        stream0 = get_raw_stream(0)
        triton_poi_fused_lift_fresh_0.run(buf0, 1, grid=grid(1), stream=stream0)
        buf1 = empty_strided_cuda((10, ), (1, ), torch.int64)
        # Source Nodes: [arange], Original ATen: [aten.arange]
        triton_poi_fused_arange_1.run(buf1, 10, grid=grid(10), stream=stream0)
        buf2 = empty_strided_cuda((), (), torch.int64)
        # Source Nodes: [tensor_1], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_2.run(buf2, 1, grid=grid(1), stream=stream0)
        buf3 = empty_strided_cuda((), (), torch.int64)
        # Source Nodes: [tensor_2], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_3.run(buf3, 1, grid=grid(1), stream=stream0)
        buf4 = empty_strided_cuda((), (), torch.int64)
        # Source Nodes: [tensor_3], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_4.run(buf4, 1, grid=grid(1), stream=stream0)
        buf5 = empty_strided_cuda((), (), torch.bool)
        # Source Nodes: [tensor_4], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_5.run(buf5, 1, grid=grid(1), stream=stream0)
        buf6 = empty_strided_cuda((), (), torch.bool)
        # Source Nodes: [tensor_5], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_0.run(buf6, 1, grid=grid(1), stream=stream0)
    return (buf0, buf2, buf3, buf4, buf5, buf6, buf1, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    fn = lambda: call([])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

As we can see, the file defines six Triton kernels, and the call function calls each of these six kernels. But an nsys profile reports the following:


 ** CUDA Kernel Launch & Exec Time Summary (cuda_kern_exec_sum):                                                                                               
                                                                                                                                                               
  PID    TID   DevId  Count  QCount  TAvg (ns)  TMed (ns)  TMin (ns)  TMax (ns)  TStdDev (ns)  AAvg (ns)  AMed (ns)  AMin (ns)  AMax (ns)  AStdDev (ns)  QAvg (ns)  QMed (ns)  QMin (ns)  QMax (ns)  QStdDev (ns)  KAvg (ns)  KMed (ns)  KMin (ns)  KMax (ns)  KStdDev (ns)     API Name     Kernel Name 
 -----  -----  -----  -----  ------  ---------  ---------  ---------  ---------  ------------  ---------  ---------  ---------  ---------  ------------  ---------  ---------  ---------  ---------  ------------  ---------  ---------  ---------  ---------  ------------  --------------  ------------
 6,093  6,093      0    600     559    5,935.2    5,614.0      5,244     20,308       1,314.0    3,956.5    3,442.0      3,088     18,173       1,546.9    1,002.9    1,013.0        457      1,635         126.7    1,119.4    1,120.0      1,087      1,760          34.1  cuLaunchKernel  triton__0d1c
 6,093  6,093      0    100      92    5,790.7    5,705.5      5,479      7,642         323.6    3,855.3    3,571.5      3,322      7,405         813.5      969.1      974.0        325      1,241         100.4    1,134.1    1,120.0      1,088      1,408          34.5  cuLaunchKernel  triton__0d1 
                                                                                                                                                               

Only two kernels are reported, triton__0d1c and triton__0d1.
Why only these two kernels? I expected six. And how are these kernels named? I just want to map each kernel running on the GPU (as reported by nsys) back to its corresponding Triton code.

Can you please explain what is happening here?

That is weird, not sure what is going on.

What does the pytorch profiler say?


@jansel I wrapped the code as follows to make use of PyTorch Profiler:

...
...
from torch.profiler import profile, record_function, ProfilerActivity

if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    with profile(activities=[
        ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("model_inference"):
            compiled_module_main('None', benchmark_compiled_module)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

The output which is shown by PyTorch Profiler:

$ python3 output_code_pytorch_profiler.py 
STAGE:2024-04-13 12:42:11 23865:23865 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
0.000155
STAGE:2024-04-13 12:42:11 23865:23865 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-04-13 12:42:11 23865:23865 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     model_inference         0.00%       0.000us         0.00%       0.000us       0.000us      16.711ms        95.98%      16.711ms      16.711ms             1  
                     model_inference        65.76%      14.219ms        99.99%      21.620ms      21.620ms       0.000us         0.00%     817.000us     817.000us             1  
                        triton__0d1c         0.00%       0.000us         0.00%       0.000us       0.000us     600.000us         3.45%     600.000us       1.000us           600  
       triton_poi_fused_lift_fresh_0         5.61%       1.212ms         9.01%       1.949ms       9.745us     200.000us         1.15%     227.000us       1.135us           200  
           triton_poi_fused_arange_1         2.88%     622.000us         4.42%     955.000us       9.550us     100.000us         0.57%     114.000us       1.140us           100  
       triton_poi_fused_lift_fresh_2         2.61%     564.000us         4.00%     865.000us       8.650us     100.000us         0.57%     114.000us       1.140us           100  
       triton_poi_fused_lift_fresh_3         2.46%     531.000us         3.85%     832.000us       8.320us     100.000us         0.57%     114.000us       1.140us           100  
       triton_poi_fused_lift_fresh_4         2.42%     523.000us         3.81%     823.000us       8.230us     100.000us         0.57%     114.000us       1.140us           100  
       triton_poi_fused_lift_fresh_5         2.35%     509.000us         3.75%     810.000us       8.100us     100.000us         0.57%     114.000us       1.140us           100  
                         triton__0d1         0.00%       0.000us         0.00%       0.000us       0.000us     100.000us         0.57%     100.000us       1.000us           100  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 21.623ms
Self CUDA time total: 17.411ms

In the PyTorch Profiler, the calls to those Triton functions are shown.
But the names triton__0d1c and triton__0d1 also appear (the only ones shown by nsys).

I am a bit confused: what are triton__0d1c and triton__0d1, and what is the corresponding code for them? Why are triton_poi_fused_lift_fresh_0, triton_poi_fused_arange_1, …, triton_poi_fused_lift_fresh_5 not shown in nsys? What are they? Aren't they kernels?


Not sure, but my hunch is the following:
The names triton_poi_fused_lift_fresh_0, triton_poi_fused_arange_1, …, triton_poi_fused_lift_fresh_5 are just CPU-side wrappers that launch the Triton kernels triton__0d1c and triton__0d1 on the GPU. In other words, these six names map to the two device kernels.

For example, if I modify the call function as follows:

def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1, ), (1, ), torch.bool)
        # Source Nodes: [tensor], Original ATen: [aten.lift_fresh]
        stream0 = get_raw_stream(0)
        triton_poi_fused_lift_fresh_0.run(buf0, 1, grid=grid(1), stream=stream0)
        buf1 = empty_strided_cuda((10, ), (1, ), torch.int64)
        # Source Nodes: [arange], Original ATen: [aten.arange]
        triton_poi_fused_arange_1.run(buf1, 10, grid=grid(10), stream=stream0)
        buf2 = empty_strided_cuda((), (), torch.int64)
        # Source Nodes: [tensor_1], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_2.run(buf2, 1, grid=grid(1), stream=stream0)
        buf3 = empty_strided_cuda((), (), torch.int64)
        # Source Nodes: [tensor_2], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_3.run(buf3, 1, grid=grid(1), stream=stream0)
        buf4 = empty_strided_cuda((), (), torch.int64)
        # Source Nodes: [tensor_3], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_4.run(buf4, 1, grid=grid(1), stream=stream0)
        buf5 = empty_strided_cuda((), (), torch.bool)
        # Source Nodes: [tensor_4], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_5.run(buf5, 1, grid=grid(1), stream=stream0)
        # Removing the call for mapping test
        # buf6 = empty_strided_cuda((), (), torch.bool)
        # # Source Nodes: [tensor_5], Original ATen: [aten.lift_fresh]
        # triton_poi_fused_lift_fresh_0.run(buf6, 1, grid=grid(1), stream=stream0)
        buf6 = None
    return (buf0, buf2, buf3, buf4, buf5, buf6, buf1, )

The PyTorch Profiler reports:

$ python3 output_code_pytorch_profiler.py 
STAGE:2024-04-13 14:38:21 24925:24925 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
0.000132
STAGE:2024-04-13 14:38:21 24925:24925 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-04-13 14:38:21 24925:24925 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     model_inference         0.00%       0.000us         0.00%       0.000us       0.000us      14.165ms        95.94%      14.165ms      14.165ms             1  
                     model_inference        67.14%      12.104ms        99.91%      18.011ms      18.011ms       0.000us         0.00%     701.000us     701.000us             1  
                        triton__0d1c         0.00%       0.000us         0.00%       0.000us       0.000us     500.000us         3.39%     500.000us       1.000us           500  
       triton_poi_fused_lift_fresh_0         3.74%     675.000us         5.89%       1.061ms      10.610us     100.000us         0.68%     114.000us       1.140us           100  
           triton_poi_fused_arange_1         3.37%     607.000us         5.06%     912.000us       9.120us     100.000us         0.68%     114.000us       1.140us           100  
       triton_poi_fused_lift_fresh_2         2.93%     528.000us         4.60%     830.000us       8.300us     100.000us         0.68%     114.000us       1.140us           100  
       triton_poi_fused_lift_fresh_3         2.80%     505.000us         4.47%     806.000us       8.060us     100.000us         0.68%     113.000us       1.130us           100  
       triton_poi_fused_lift_fresh_4         2.94%     530.000us         4.61%     831.000us       8.310us     100.000us         0.68%     113.000us       1.130us           100  
       triton_poi_fused_lift_fresh_5         2.82%     508.000us         4.50%     811.000us       8.110us     100.000us         0.68%     113.000us       1.130us           100  
                         triton__0d1         0.00%       0.000us         0.00%       0.000us       0.000us     100.000us         0.68%     100.000us       1.000us           100  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 18.028ms
Self CUDA time total: 14.765ms

The number of times triton_poi_fused_lift_fresh_0 is called drops, and so does the count for triton__0d1c, by the same amount.

Along the same lines, there seems to be the following mapping:
triton_poi_fused_lift_fresh_0 → triton__0d1c
triton_poi_fused_arange_1 → triton__0d1
triton_poi_fused_lift_fresh_2 → triton__0d1c
triton_poi_fused_lift_fresh_3 → triton__0d1c
triton_poi_fused_lift_fresh_4 → triton__0d1c
triton_poi_fused_lift_fresh_5 → triton__0d1c

You could also try recompiling with torch._inductor.config.profiler_mark_wrapper_call = True, which will turn on some annotations that improve the profiler output.
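
A minimal sketch of what that could look like, assuming the flag lives in torch._inductor.config in your build:

import torch
import torch._inductor.config as inductor_config

# Turn on wrapper-call annotations before compiling so they show up in profiles.
inductor_config.profiler_mark_wrapper_call = True

@torch.compile(mode="reduce-overhead")
def toy_example(a, b):
    return a * b

toy_example(torch.randn(10, device="cuda"), torch.randn(10, device="cuda"))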