Hi @jansel , I wonder why inductor chooses Triton to generate CUDA kernels instead of other solutions like TVM / XLA?
Ah, my bad, I missed the earlier discussion. Thanks for pointing that out, @Lezcano!
So, if I understand correctly, the main reason for not choosing TVM is that Tensor IR requires more expert knowledge than Triton to get good performance?
It seems the main reason for choosing Triton is that it focuses on NVIDIA GPU optimizations, while the others (TVM/XLA) are not tied to GPUs.
After digging into PyTorch’s matmul Triton template, I think it is rather generalized and not bound to GPUs. Hardware vendors can still port Triton and do their own transforms with this “tiled” language.
However, PyTorch’s inductor implementation is indeed rather bound to GPUs, which makes it harder to separate inductor’s core logic from its calls to CUDA APIs.
Does the fact that dynamo is only for Linux and Mac mean that contributing to inductor is not possible on Windows?
@Idank96 I’d expect dynamo/inductor to work with Windows Subsystem for Linux (WSL). Though I don’t know anyone who has tried that.
One easy way to contribute on Windows would be to try that, fix any bugs (if any), and write instructions that other Windows users could follow.
If you want to try to add Windows support without WSL, we also welcome pull requests to add support for Windows.
It works on WSL, yes. I’ve been using the pytorch nightly build and it works fine.
@jansel What is unclear to me is how PT2.0 uses CUDA Graphs in torch.compile — the engineering details mainly.
I understand the flow:
PyTorch → TorchDynamo → TorchInductor → Triton → NVIDIA GPU
With the TORCH_COMPILE_DEBUG=1 env variable exported, I get the following few messages at the end of the logs:
...
...
[2024-03-25 11:07:02,184] torch._inductor.cudagraph_trees: [INFO] recording cudagraph tree for None
[2024-03-25 11:07:02,299] torch._inductor.cudagraph_trees: [DEBUG] Running warmup of function 0
[2024-03-25 11:07:02,462] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 0b0f632ae495eb55ca41fe06a33f3ed1
{'cassette': 0.9992951154708862, 'tape player': 0.0005148712079972029, 'cassette player': 0.0001740186125971377, 'radio': 1.0334849321225192e-05, 'CD player': 4.437319603312062e-06}
[2024-03-25 11:07:02,463] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* _splitext /usr/lib/python3.10/genericpath.py 121
[2024-03-25 11:07:02,462] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: 0b0f632ae495eb55ca41fe06a33f3ed1
[2024-03-25 11:07:02,477] torch._inductor.cudagraph_trees: [DEBUG] Recording function 0 of graph recording id 0
[2024-03-25 11:07:02,585] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 0b0f632ae495eb55ca41fe06a33f3ed1
{'cassette': 0.9992951154708862, 'tape player': 0.0005148712079972029, 'cassette player': 0.0001740186125971377, 'radio': 1.0334849321225192e-05, 'CD player': 4.437319603312062e-06}
[2024-03-25 11:07:02,586] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: 0b0f632ae495eb55ca41fe06a33f3ed1
[2024-03-25 11:07:02,604] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 0b0f632ae495eb55ca41fe06a33f3ed1
{'cassette': 0.9992951154708862, 'tape player': 0.0005148712079972029, 'cassette player': 0.0001740186125971377, 'radio': 1.0334849321225192e-05, 'CD player': 4.437319603312062e-06}
...
...
I called my inference three times. From the documentation about CUDAGraph Trees, it seems the first call warms up, the second call records the CUDA Graph, and the third call makes use of the captured CUDA Graph.
My question is, in the above flow, where does CUDAGraph fit in? And who drives it?
PyTorch → TorchDynamo → TorchInductor → Triton → CUDAGraph Capture (who drives it and the logic behind what to capture and what not to)→ NVIDIA GPU
I suppose the flow is as above, but the question remains: who drives the CUDAGraph Trees? Is there any documentation regarding this? (Documentation helps before looking into the code, say torch/_inductor/cudagraph_trees.py, lest I misinterpret anything.)
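For readers following along, the three-call behaviour described above (warm up, record, replay) can be modelled with a tiny pure-Python sketch. This is only a conceptual model: `ToyGraphRunner` and its `phases` list are hypothetical names, nothing here touches CUDA, and it is not how torch/_inductor/cudagraph_trees.py is actually written.

```python
from typing import Callable, List, Optional


class ToyGraphRunner:
    """Toy model of warmup -> record -> replay (hypothetical, no CUDA involved)."""

    def __init__(self, fn: Callable[[List[float]], List[float]]) -> None:
        self.fn = fn
        self.calls = 0
        # Stand-in for a captured CUDA graph; None until "recording" happens.
        self.recorded: Optional[Callable[[List[float]], List[float]]] = None
        self.phases: List[str] = []

    def __call__(self, inputs: List[float]) -> List[float]:
        self.calls += 1
        if self.calls == 1:
            # First call: run normally so one-time setup happens outside capture.
            self.phases.append("warmup")
            return self.fn(inputs)
        if self.recorded is None:
            # Second call: run again while "recording" the work.
            self.phases.append("record")
            self.recorded = self.fn
            return self.fn(inputs)
        # Later calls: reuse the recording instead of going through setup again.
        self.phases.append("replay")
        return self.recorded(inputs)


runner = ToyGraphRunner(lambda xs: [x * 2 for x in xs])
for _ in range(3):
    runner([1.0, 2.0])
print(runner.phases)  # ['warmup', 'record', 'replay']
```

The three phases line up with the "Running warmup of function 0" and "Recording function 0" messages in the log, followed by a third call that logs neither.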
cc @eellison who should be able to add more details.
At compile time it is implemented as:
- Compiler passes that analyze the model to see if cudagraphs will work (cudagraphs doesn’t support many things that must be checked for, mainly any work being done on the CPU)
- If it can be used, codegen a runtime wrapper around the model that applies cudagraphs
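The two steps above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not inductor's code: `Op`, `can_use_cudagraphs`, and `wrap_for_runtime` are hypothetical names, and the only check modelled is the "no CPU work" condition mentioned in the list.

```python
from typing import Any, Callable, List, NamedTuple


class Op(NamedTuple):
    name: str
    device: str  # "cuda" or "cpu" in this toy model


def can_use_cudagraphs(ops: List[Op]) -> bool:
    # Hypothetical compile-time check: any op running on the CPU disqualifies the graph.
    return all(op.device == "cuda" for op in ops)


def wrap_for_runtime(
    ops: List[Op],
    compiled: Callable[..., Any],
    cudagraph_wrapper: Callable[[Callable[..., Any]], Any],
) -> Any:
    # Apply the runtime wrapper only when the check passes; otherwise
    # hand back the compiled callable unchanged.
    if can_use_cudagraphs(ops):
        return cudagraph_wrapper(compiled)
    return compiled


gpu_only = [Op("mul", "cuda")]
mixed = [Op("mul", "cuda"), Op("item", "cpu")]
print(can_use_cudagraphs(gpu_only))  # True
print(can_use_cudagraphs(mixed))     # False
```

The real passes check for more than device placement; the sketch only shows the shape of "check first, then wrap".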
I was going through the inductor code base and found the following:
The above code inside compile_fx_inner uses the cudagraphify method, which ultimately (using CUDAGraph Trees) tries to make a CUDAGraph out of the compiled_graph.current_callable.
As per the signature of cudagraphify, the first argument is:
So, compiled_graph.current_callable is of type torch.fx.GraphModule.
Now my question is, what is this compiled_graph.current_callable: torch.fx.GraphModule?
In the compile_fx_inner function, I find compiled_graph defined as above.
Now gm is again a torch.fx.GraphModule.
So, questions here:
- What is the torch.fx.GraphModule instance that compile_fx_inner takes as input? Is it a fx_graph module passed down from TorchDynamo?
- What is the torch.fx.GraphModule instance that is passed to cudagraphify in compile_fx_inner? Is it an optimized GraphModule with triton kernels built into it?
- Next, is my following understanding correct:
  We parse PyTorch code to produce many FX Graphs (many because we might have graph breaks). These FX graphs from PyTorch code are GraphModules(), and TorchInductor makes triton code out of each FX Graph. So, in TorchInductor, we have a set of GraphModules (compiled to use triton), and then for each of these GraphModules, we decide whether to use CUDA Graphs.
Somewhat related @ezyang recently published a podcast on cudagraph trees:
I believe the model: torch.fx.GraphModule type annotation is wrong; it should be something like:
model: Callable[[List[Tensor]], List[Tensor]]
It is the compile wrapper code generated by TorchInductor that launches all the Triton kernels.
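To make that calling convention concrete, here is a small stand-in written with plain Python lists instead of tensors. `Buffer` and `run_wrapped` are hypothetical names for illustration; the `call` function only mimics the shape of a list-in, list-out wrapper and launches no kernels.

```python
from typing import Callable, List

Buffer = List[float]  # stand-in for torch.Tensor in this sketch


def call(args: List[Buffer]) -> List[Buffer]:
    # Unpack the inputs, then clear the list so the caller's list
    # no longer holds references to them.
    a, b = args
    args.clear()
    out = [x * y for x, y in zip(a, b)]  # stand-in for a kernel launch
    return [out]


def run_wrapped(
    model: Callable[[List[Buffer]], List[Buffer]], inputs: List[Buffer]
) -> List[Buffer]:
    # The callee takes one list of inputs rather than positional arguments.
    return model(list(inputs))


print(run_wrapped(call, [[1.0, 2.0], [3.0, 4.0]]))  # [[3.0, 8.0]]
```

The generated `call(args)` in the log further down has the same outline: it unpacks `args`, calls `args.clear()`, runs the kernel, and returns the output buffers (as a tuple there, a list here).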
Is there any debug environment variable which can be set to have a glance at the wrapper code?
If we take the small example:
import os
os.environ["TORCH_COMPILE_DEBUG"] = "1"
import torch
@torch.compile(mode="reduce-overhead")
def toy_example(a, b):
    return a * b
tensor1 = torch.randn(10, device="cuda")
tensor2 = torch.randn(10, device="cuda")
toy_example(tensor1, tensor2)
toy_example(tensor1, tensor2)
toy_example(tensor1, tensor2)
I get the following debug log:
[2024-03-25 22:31:29,930] torch._dynamo.eval_frame: [DEBUG] Saving dynamo config and hash for new compiled object(s). Hash: 9fab733732f45cfca18c83a96ae70468
[2024-03-25 22:31:29,941] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: 9fab733732f45cfca18c83a96ae70468
[2024-03-25 22:31:29,944] [4/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing toy_example /tmp/ipykernel_594021/431827302.py:2
[2024-03-25 22:31:29,944] [4/0] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2024-03-25 22:31:29,945] [4/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /tmp/ipykernel_594021/431827302.py:2 in toy_example
[2024-03-25 22:31:29,945] [4/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] @torch.compile(mode="reduce-overhead")
[2024-03-25 22:31:29,946] [4/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /tmp/ipykernel_594021/431827302.py:4 in toy_example
[2024-03-25 22:31:29,946] [4/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] return a * b
[2024-03-25 22:31:29,946] [4/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST a []
[2024-03-25 22:31:29,946] [4/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST b [LazyVariableTracker()]
[2024-03-25 22:31:29,947] [4/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE BINARY_MULTIPLY None [LazyVariableTracker(), LazyVariableTracker()]
[2024-03-25 22:31:29,947] [4/0] torch._dynamo.output_graph: [DEBUG] create_graph_input L_a_ L['a']
[2024-03-25 22:31:29,948] [4/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['a'] (10,) [<DimDynamic.STATIC: 2>] [None]
[2024-03-25 22:31:29,949] [4/0] torch._dynamo.output_graph: [DEBUG] create_graph_input L_b_ L['b']
[2024-03-25 22:31:29,950] [4/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['b'] (10,) [<DimDynamic.STATIC: 2>] [None]
[2024-03-25 22:31:29,952] [4/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RETURN_VALUE None [TensorVariable()]
[2024-03-25 22:31:29,952] [4/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing toy_example (RETURN_VALUE)
[2024-03-25 22:31:29,953] [4/0] torch._dynamo.symbolic_convert: [DEBUG] RETURN_VALUE triggered compile
[2024-03-25 22:31:29,953] [4/0] torch._dynamo.output_graph: [DEBUG] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /tmp/ipykernel_594021/431827302.py, line 4 in toy_example>], graph_break=False)
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG] ===== __compiled_fn_7 =====
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG] <eval_with_key>.67 class GraphModule(torch.nn.Module):
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG] def forward(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG] l_a_ = L_a_
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG] l_b_ = L_b_
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG] # File: /tmp/ipykernel_594021/431827302.py:4, code: return a * b
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG] mul = l_a_ * l_b_; l_a_ = l_b_ = None
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG] return (mul,)
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2024-03-25 22:31:29,954] [4/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] __compiled_fn_7 <eval_with_key>.67 opcode name target args kwargs
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] ------------- ------ ----------------------- ------------ --------
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder l_a_ L_a_ () {}
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder l_b_ L_b_ () {}
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] call_function mul <built-in function mul> (l_a_, l_b_) {}
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG] output output output ((mul,),) {}
[2024-03-25 22:31:29,955] [4/0] torch._dynamo.output_graph.__graph: [DEBUG]
[2024-03-25 22:31:29,956] [4/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] TRACED GRAPH TENSOR SIZES
[2024-03-25 22:31:29,956] [4/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] ===== __compiled_fn_7 =====
[2024-03-25 22:31:29,956] [4/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] l_a_: (10,)
[2024-03-25 22:31:29,956] [4/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] l_b_: (10,)
[2024-03-25 22:31:29,956] [4/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] mul: (10,)
[2024-03-25 22:31:29,956] [4/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG]
[2024-03-25 22:31:29,957] [4/0] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO] TRACED GRAPH
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO] ===== Forward graph 7 =====
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO] <eval_with_key>.71 from /home/abhishek/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:511 in wrapped class <lambda>(torch.nn.Module):
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO] def forward(self, arg0_1: "f32[10]", arg1_1: "f32[10]"):
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO] # File: /tmp/ipykernel_594021/431827302.py:4, code: return a * b
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO] mul: "f32[10]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO] return (mul,)
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO]
[2024-03-25 22:31:29,966] [4/0] torch._functorch._aot_autograd.dispatch_and_compile_graph.__aot_graphs: [INFO]
[2024-03-25 22:31:29,968] [4/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 4
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO] TRACED GRAPH
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO] ===== AFTER POST GRAD =====
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO] <eval_with_key>.72 from /home/abhishek/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:511 in wrapped class <lambda>(torch.nn.Module):
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO] def forward(self, arg0_1: "f32[10]", arg1_1: "f32[10]"):
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO] # File: /tmp/ipykernel_594021/431827302.py:4, code: return a * b
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO] mul: "f32[10]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO] return (mul,)
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO]
[2024-03-25 22:31:29,972] [4/0] torch._inductor.compile_fx.__post_grad_graphs: [INFO]
[2024-03-25 22:31:29,974] [4/0] torch._inductor.graph: [DEBUG] lowering %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
[2024-03-25 22:31:29,974] [4/0] torch._inductor.graph: [DEBUG] lowering %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
[2024-03-25 22:31:29,975] [4/0] torch._inductor.graph: [DEBUG] lowering %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
[2024-03-25 22:31:29,976] [4/0] torch._inductor.graph: [DEBUG] via <function mul at 0x7f952b69ba30>
[2024-03-25 22:31:29,976] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] eval True == True [statically known]
[2024-03-25 22:31:29,977] [4/0] torch._inductor.graph: [DEBUG] lowering return (mul,)
[2024-03-25 22:31:29,978] [4/0] torch._inductor.graph: [DEBUG] Force channels last inputs for 0 conv for the current graph with id 4
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] scheduling ComputedBuffer(name='buf0', layout=FixedLayout('cuda', torch.float32, size=[10], stride=[1]), data=Pointwise(
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] 'cuda',
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] torch.float32,
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] def inner_fn(index):
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] i0 = index
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] tmp0 = ops.load(arg0_1, i0)
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] tmp1 = ops.load(arg1_1, i0)
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] tmp2 = tmp0 * tmp1
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] return tmp2
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] ,
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] ranges=[10],
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] origin_node=mul,
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] origins={mul}
[2024-03-25 22:31:29,981] [4/0] torch._inductor.scheduler: [DEBUG] ))
[2024-03-25 22:31:29,982] [4/0] torch._inductor.scheduler: [DEBUG] scheduling output buf0
[2024-03-25 22:31:29,983] [4/0] torch._inductor.debug: [INFO] Writing debug ir to /media/abhishek/Abhishek_NVMe/shweta_machine/trace_analysis/start-fc-gpu/fc-http-gpu-inference-torchhub-cv-mobilenet-v2/src/code/torch-compiled-reduce-overhead-dot-graphs/torch_compile_debug/run_2024_03_25_22_14_47_042588-pid_594021/torchinductor/model__7_inference_13.4/ir_pre_fusion.txt
[2024-03-25 22:31:29,985] [4/0] torch._inductor.debug: [INFO] Writing debug ir to /media/abhishek/Abhishek_NVMe/shweta_machine/trace_analysis/start-fc-gpu/fc-http-gpu-inference-torchhub-cv-mobilenet-v2/src/code/torch-compiled-reduce-overhead-dot-graphs/torch_compile_debug/run_2024_03_25_22_14_47_042588-pid_594021/torchinductor/model__7_inference_13.4/ir_post_fusion.txt
[2024-03-25 22:31:30,006] [4/0] torch._inductor.scheduler: [DEBUG] Generating code for node buf0 with estimated runtime 0.267857
[2024-03-25 22:31:30,010] [4/0] torch._inductor.codegen.triton: [DEBUG] Generating kernel code with kernel_name: triton_poi_fused_mul_0
[2024-03-25 22:31:30,645] [4/0] torch._inductor.triton_heuristics: [DEBUG] CachingAutotuner gets 1 configs
[2024-03-25 22:31:30,646] [4/0] torch._inductor.triton_heuristics: [DEBUG] XBLOCK: 16, num_warps: 1, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False
[2024-03-25 22:31:30,978] [4/0] torch._inductor.graph: [DEBUG] Output code written to: /tmp/torchinductor_abhishek/6i/c6imvaznfo4qm5nsa3hyxxmktztn4risj3dr3los4gefjjf62p63.py
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] Output code:
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from ctypes import c_void_p, c_long
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import torch
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import math
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import random
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import os
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import tempfile
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from math import inf, nan
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.hooks import run_intermediate_hooks
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.utils import maybe_profile
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.codegen.memory_planning import _align as align
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch import device, empty, empty_strided
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.codecache import AsyncCompile
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.select_algorithm import extern_kernels
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] aten = torch.ops.aten
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] inductor_ops = torch.ops.inductor
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] alloc_from_pool = torch.ops.inductor._alloc_from_pool
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] async_compile = AsyncCompile()
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] # kernel path: /tmp/torchinductor_abhishek/xz/cxzhsm3ysi6dh7mibspq3kqhgzqiqdq5qebnkt52iphjqmhccl66.py
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] # Source Nodes: [mul], Original ATen: [aten.mul]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] # mul => mul
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] triton_poi_fused_mul_0 = async_compile.triton('triton_', '''
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import triton
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import triton.language as tl
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.ir import ReductionHint
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.ir import TileHint
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.triton_heuristics import AutotuneHint, pointwise
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.utils import instance_descriptor
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor import triton_helpers
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] @pointwise(
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] size_hints=[16],
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] filename=__file__,
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_0', 'mutated_arg_names': []},
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] min_elem_per_thread=0
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] )
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] @triton.jit
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] xnumel = 10
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] xoffset = tl.program_id(0) * XBLOCK
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] xindex = xoffset + tl.arange(0, XBLOCK)[:]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] xmask = xindex < xnumel
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] x0 = xindex
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] tmp0 = tl.load(in_ptr0 + (x0), xmask)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] tmp1 = tl.load(in_ptr1 + (x0), xmask)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] tmp2 = tmp0 * tmp1
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] tl.store(out_ptr0 + (x0), tmp2, xmask)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] ''')
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import triton
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] import triton.language as tl
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.triton_heuristics import grid, start_graph, end_graph
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] async_compile.wait(globals())
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] del async_compile
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] def call(args):
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] arg0_1, arg1_1 = args
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] args.clear()
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] assert_size_stride(arg0_1, (10, ), (1, ))
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] assert_size_stride(arg1_1, (10, ), (1, ))
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] with torch.cuda._DeviceGuard(0):
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] torch.cuda.set_device(0) # no-op to ensure context
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] buf0 = empty((10, ), device='cuda', dtype=torch.float32)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] # Source Nodes: [mul], Original ATen: [aten.mul]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] stream0 = get_cuda_stream(0)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] triton_poi_fused_mul_0.run(arg0_1, arg1_1, buf0, 10, grid=grid(10), stream=stream0)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] del arg0_1
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] del arg1_1
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] return (buf0, )
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] def benchmark_compiled_module(times=10, repeat=10):
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._dynamo.testing import rand_strided
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.utils import print_performance
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] arg0_1 = rand_strided((10, ), (1, ), device='cuda:0', dtype=torch.float32)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] arg1_1 = rand_strided((10, ), (1, ), device='cuda:0', dtype=torch.float32)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] fn = lambda: call([arg0_1, arg1_1])
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] return print_performance(fn, times=times, repeat=repeat)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] if __name__ == "__main__":
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] from torch._inductor.wrapper_benchmark import compiled_module_main
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG] compiled_module_main('None', benchmark_compiled_module)
[2024-03-25 22:31:30,979] [4/0] torch._inductor.graph.__output_code: [DEBUG]
[2024-03-25 22:31:30,980] [4/0] torch._inductor.graph.__output_code: [INFO] Output code written to: /tmp/torchinductor_abhishek/6i/c6imvaznfo4qm5nsa3hyxxmktztn4risj3dr3los4gefjjf62p63.py
[2024-03-25 22:31:30,981] [4/0] torch._inductor.compile_fx: [DEBUG] FX codegen and compilation took 1.013s
[2024-03-25 22:31:30,984] [4/0] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 4
[2024-03-25 22:31:30,984] [4/0] torch._inductor.debug: [WARNING] model__7_inference_13 debug trace: /tmp/torchinductor_abhishek/6i/c6imvaznfo4qm5nsa3hyxxmktztn4risj3dr3los4gefjjf62p63.debug
[2024-03-25 22:31:30,987] [4/0] torch._dynamo.output_graph: [INFO] Step 2: done compiler function inductor
[2024-03-25 22:31:30,988] [4/0] torch.fx.experimental.symbolic_shapes: [INFO] produce_guards
[2024-03-25 22:31:30,989] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['a'].size()[0] 10 None
[2024-03-25 22:31:30,989] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['a'].stride()[0] 1 None
[2024-03-25 22:31:30,989] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['a'].storage_offset() 0 None
[2024-03-25 22:31:30,990] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['b'].size()[0] 10 None
[2024-03-25 22:31:30,990] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['b'].stride()[0] 1 None
[2024-03-25 22:31:30,990] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] track_symint L['b'].storage_offset() 0 None
[2024-03-25 22:31:30,991] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['a'].size()[0] == 10
[2024-03-25 22:31:30,991] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['a'].stride()[0] == 1
[2024-03-25 22:31:30,991] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['a'].storage_offset() == 0
[2024-03-25 22:31:30,992] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['b'].size()[0] == 10
[2024-03-25 22:31:30,992] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['b'].stride()[0] == 1
[2024-03-25 22:31:30,993] [4/0] torch.fx.experimental.symbolic_shapes: [DEBUG] Skipping guard L['b'].storage_offset() == 0
[2024-03-25 22:31:30,993] [4/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS:
[2024-03-25 22:31:30,995] [4/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['a'], '_dynamo_dynamic_indices') == False # return a * b # mp/ipykernel_594021/431827302.py:4 in toy_example
[2024-03-25 22:31:30,995] [4/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['b'], '_dynamo_dynamic_indices') == False # return a * b # mp/ipykernel_594021/431827302.py:4 in toy_example
[2024-03-25 22:31:30,996] [4/0] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:379 in init_ambient_guards
[2024-03-25 22:31:30,997] [4/0] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(140282244472576)) # _dynamo/output_graph.py:385 in init_ambient_guards
[2024-03-25 22:31:30,998] [4/0] torch._dynamo.guards.__guards: [DEBUG] ___compile_config_hash() == '9fab733732f45cfca18c83a96ae70468' # _dynamo/output_graph.py:387 in init_ambient_guards
[2024-03-25 22:31:30,998] [4/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['a'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[10], stride=[1]) # return a * b # mp/ipykernel_594021/431827302.py:4 in toy_example
[2024-03-25 22:31:30,999] [4/0] torch._dynamo.guards.__guards: [DEBUG] check_tensor(L['b'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[10], stride=[1]) # return a * b # mp/ipykernel_594021/431827302.py:4 in toy_example
[2024-03-25 22:31:31,001] torch._inductor.cudagraph_trees: [INFO] recording cudagraph tree for None
[2024-03-25 22:31:31,082] torch._inductor.cudagraph_trees: [DEBUG] Running warmup of function 0
[2024-03-25 22:31:31,083] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 9fab733732f45cfca18c83a96ae70468
[2024-03-25 22:31:31,084] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: 9fab733732f45cfca18c83a96ae70468
[2024-03-25 22:31:31,084] torch._inductor.cudagraph_trees: [DEBUG] Recording function 0 of graph recording id 0
[2024-03-25 22:31:31,157] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 9fab733732f45cfca18c83a96ae70468
[2024-03-25 22:31:31,157] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: 9fab733732f45cfca18c83a96ae70468
[2024-03-25 22:31:31,158] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: 9fab733732f45cfca18c83a96ae70468
Is that wrapper code shown here? If so, which one is it?
Yes the wrapper code is this function:
There is also TORCH_LOGS=+output_code
(end of the original post got cropped so posting the end as a reply to the original post)
I guess it is the code dumped to the file mentioned in the log (Output code written to: /tmp/torchinductor_abhishek/6i/c6imvaznfo4qm5nsa3hyxxmktztn4risj3dr3los4gefjjf62p63.py).
It happens to be:
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_abhishek/xz/cxzhsm3ysi6dh7mibspq3kqhgzqiqdq5qebnkt52iphjqmhccl66.py
# Source Nodes: [mul], Original ATen: [aten.mul]
# mul => mul
triton_poi_fused_mul_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@pointwise(
size_hints=[16],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_0', 'mutated_arg_names': []},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 10
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = tmp0 * tmp1
    tl.store(out_ptr0 + (x0), tmp2, xmask)
''')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
async_compile.wait(globals())
del async_compile
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (10, ), (1, ))
    assert_size_stride(arg1_1, (10, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty((10, ), device='cuda', dtype=torch.float32)
        # Source Nodes: [mul], Original ATen: [aten.mul]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_mul_0.run(arg0_1, arg1_1, buf0, 10, grid=grid(10), stream=stream0)
        del arg0_1
        del arg1_1
    return (buf0, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((10, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg1_1 = rand_strided((10, ), (1, ), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([arg0_1, arg1_1])
    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
And the function of interest is call in the above module, right (the compiled wrapper code generated by TorchInductor)? Based on:
Yes, cudagraphify is wrapping call.
@jansel I am having trouble figuring something out. Suppose we have an output_code.py file generated automatically by setting the TORCH_COMPILE_DEBUG flag to 1. Say this output_code.py defines x Triton kernels and the call function launches each of these x kernels. However, profiling the file with nsys, as nsys profile -o python3 output_code.py, reports execution of only y kernels, where y < x.
For example, take the following output_code.py file (torchinductor/model__4_inference_10.1/):
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_abhishek/jr/cjrdra55i6oq3npg2ofcxhhxqozdch6g3eqzefwk7iyia2vnemp7.py
# Source Nodes: [tensor], Original ATen: [aten.lift_fresh]
# tensor => full_default
triton_poi_fused_lift_fresh_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
size_hints=[1],
filename=__file__,
triton_meta={'signature': {0: '*i1', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {1: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=(1,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_lift_fresh_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '41a3d5b7867eec3e9749fd2d5d072c8778e695c90c8a6fd62f1546cb8261bd00'},
min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.full([1], False, tl.int1)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp0, None)
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_abhishek/sa/csa63dnba3cwdups6w77k5ehbwe2srxwqsgnkaxjxvexrisqpofw.py
# Source Nodes: [arange], Original ATen: [aten.arange]
# arange => iota
triton_poi_fused_arange_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
size_hints=[16],
filename=__file__,
triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_arange_1', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '41a3d5b7867eec3e9749fd2d5d072c8778e695c90c8a6fd62f1546cb8261bd00'},
min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 10
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_abhishek/yq/cyqzktrq7yalywzrlmyzqcyugfkskjk6ateieldildbva3hwulja.py
# Source Nodes: [tensor_1], Original ATen: [aten.lift_fresh]
# tensor_1 => full_default_1
triton_poi_fused_lift_fresh_2 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
size_hints=[1],
filename=__file__,
triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {1: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=(1,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_lift_fresh_2', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '41a3d5b7867eec3e9749fd2d5d072c8778e695c90c8a6fd62f1546cb8261bd00'},
min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.full([1], 0, tl.int64)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp0, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_abhishek/vk/cvk4asato2zcv7gno273d5iju2gu3zsd35tiy3m3xl2yamm5zrfe.py
# Source Nodes: [tensor_2], Original ATen: [aten.lift_fresh]
# tensor_2 => full_default_2
triton_poi_fused_lift_fresh_3 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
size_hints=[1],
filename=__file__,
triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {1: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=(1,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_lift_fresh_3', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '41a3d5b7867eec3e9749fd2d5d072c8778e695c90c8a6fd62f1546cb8261bd00'},
min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.full([1], 1, tl.int64)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp0, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_abhishek/uy/cuybdehcpe364i7q5qhx5o27nuicncddeea6lcknqr56otci23gt.py
# Source Nodes: [tensor_3], Original ATen: [aten.lift_fresh]
# tensor_3 => full_default_3
triton_poi_fused_lift_fresh_4 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
size_hints=[1],
filename=__file__,
triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {1: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=(1,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_lift_fresh_4', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '41a3d5b7867eec3e9749fd2d5d072c8778e695c90c8a6fd62f1546cb8261bd00'},
min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.full([1], 5, tl.int64)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp0, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_abhishek/qr/cqrvjsoxqqbmo5xjphzpeur2avys3e4j7komqlq5ri3kjdjeqtyl.py
# Source Nodes: [tensor_4], Original ATen: [aten.lift_fresh]
# tensor_4 => full_default_4
triton_poi_fused_lift_fresh_5 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
size_hints=[1],
filename=__file__,
triton_meta={'signature': {0: '*i1', 1: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {1: 1}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=(1,))]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_lift_fresh_5', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '41a3d5b7867eec3e9749fd2d5d072c8778e695c90c8a6fd62f1546cb8261bd00'},
min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.full([1], True, tl.int1)
    tl.store(out_ptr0 + (tl.full([XBLOCK], 0, tl.int32)), tmp0, None)
''', device_str='cuda')
async_compile.wait(globals())
del async_compile
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1, ), (1, ), torch.bool)
        # Source Nodes: [tensor], Original ATen: [aten.lift_fresh]
        stream0 = get_raw_stream(0)
        triton_poi_fused_lift_fresh_0.run(buf0, 1, grid=grid(1), stream=stream0)
        buf1 = empty_strided_cuda((10, ), (1, ), torch.int64)
        # Source Nodes: [arange], Original ATen: [aten.arange]
        triton_poi_fused_arange_1.run(buf1, 10, grid=grid(10), stream=stream0)
        buf2 = empty_strided_cuda((), (), torch.int64)
        # Source Nodes: [tensor_1], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_2.run(buf2, 1, grid=grid(1), stream=stream0)
        buf3 = empty_strided_cuda((), (), torch.int64)
        # Source Nodes: [tensor_2], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_3.run(buf3, 1, grid=grid(1), stream=stream0)
        buf4 = empty_strided_cuda((), (), torch.int64)
        # Source Nodes: [tensor_3], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_4.run(buf4, 1, grid=grid(1), stream=stream0)
        buf5 = empty_strided_cuda((), (), torch.bool)
        # Source Nodes: [tensor_4], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_5.run(buf5, 1, grid=grid(1), stream=stream0)
        buf6 = empty_strided_cuda((), (), torch.bool)
        # Source Nodes: [tensor_5], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_0.run(buf6, 1, grid=grid(1), stream=stream0)
    return (buf0, buf2, buf3, buf4, buf5, buf6, buf1, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    fn = lambda: call([])
    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
As we can see, the file defines six Triton kernels, and the call function launches each of them (triton_poi_fused_lift_fresh_0 twice). But an nsys profile reports the following:
** CUDA Kernel Launch & Exec Time Summary (cuda_kern_exec_sum):
PID TID DevId Count QCount TAvg (ns) TMed (ns) TMin (ns) TMax (ns) TStdDev (ns) AAvg (ns) AMed (ns) AMin (ns) AMax (ns) AStdDev (ns) QAvg (ns) QMed (ns) QMin (ns) QMax (ns) QStdDev (ns) KAvg (ns) KMed (ns) KMin (ns) KMax (ns) KStdDev (ns) API Name Kernel Name
----- ----- ----- ----- ------ --------- --------- --------- --------- ------------ --------- --------- --------- --------- ------------ --------- --------- --------- --------- ------------ --------- --------- --------- --------- ------------ -------------- ------------
6,093 6,093 0 600 559 5,935.2 5,614.0 5,244 20,308 1,314.0 3,956.5 3,442.0 3,088 18,173 1,546.9 1,002.9 1,013.0 457 1,635 126.7 1,119.4 1,120.0 1,087 1,760 34.1 cuLaunchKernel triton__0d1c
6,093 6,093 0 100 92 5,790.7 5,705.5 5,479 7,642 323.6 3,855.3 3,571.5 3,322 7,405 813.5 969.1 974.0 325 1,241 100.4 1,134.1 1,120.0 1,088 1,408 34.5 cuLaunchKernel triton__0d1
Only two kernels are reported, triton__0d1c and triton__0d1.
Why only these two kernels? I expected six of them. And how are these kernels named? I just want to know which Triton code corresponds to each kernel that nsys reports as running on the GPU, i.e. some kind of mapping.
Can you please explain what is happening here?
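As a quick cross-check of the "x kernels defined vs. launched" count, a small script can scan the generated file for definitions and launch sites. This is only a sketch: the two regexes are my own assumption about what the generated code looks like (name = async_compile.triton(...) for definitions, name.run(...) for launches), and SAMPLE is a shortened, hypothetical stand-in for the file above. Run on the full output_code.py shown above, it should count six definitions and seven launches per call(), since triton_poi_fused_lift_fresh_0 is launched twice.

```python
import re
from collections import Counter

# Shortened, hypothetical stand-in for an Inductor output_code.py.
# Only the lines the regexes below care about are kept.
SAMPLE = """
triton_poi_fused_lift_fresh_0 = async_compile.triton('triton_', '''...''', device_str='cuda')
triton_poi_fused_arange_1 = async_compile.triton('triton_', '''...''', device_str='cuda')
def call(args):
    triton_poi_fused_lift_fresh_0.run(buf0, 1, grid=grid(1), stream=stream0)
    triton_poi_fused_arange_1.run(buf1, 10, grid=grid(10), stream=stream0)
    # triton_poi_fused_arange_1.run(buf9, 10, grid=grid(10), stream=stream0)
    triton_poi_fused_lift_fresh_0.run(buf6, 1, grid=grid(1), stream=stream0)
"""

# Assumed shapes of the generated code (not an official Inductor format).
DEF_RE = re.compile(r"^(\w+)[ \t]*=[ \t]*async_compile\.triton\(", re.MULTILINE)
RUN_RE = re.compile(r"^[ \t]*(\w+)\.run\(", re.MULTILINE)  # skips commented-out lines


def summarize(source):
    """Return (kernel names defined, Counter of launches per kernel name)."""
    defined = DEF_RE.findall(source)
    launches = Counter(RUN_RE.findall(source))
    return defined, launches


defined, launches = summarize(SAMPLE)
print(f"{len(defined)} kernels defined, {sum(launches.values())} launches per call()")
for name in defined:
    print(f"  {name}: {launches[name]} launch(es)")
```

To use it on a real file, replace SAMPLE with the contents of output_code.py, for example via open(path).read().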
@jansel I wrapped the code as follows to make use of PyTorch Profiler:
...
...
from torch.profiler import profile, record_function, ProfilerActivity
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    with profile(activities=[
            ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("model_inference"):
            compiled_module_main('None', benchmark_compiled_module)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
The output which is shown by PyTorch Profiler:
$ python3 output_code_pytorch_profiler.py
STAGE:2024-04-13 12:42:11 23865:23865 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
0.000155
STAGE:2024-04-13 12:42:11 23865:23865 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-04-13 12:42:11 23865:23865 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
model_inference 0.00% 0.000us 0.00% 0.000us 0.000us 16.711ms 95.98% 16.711ms 16.711ms 1
model_inference 65.76% 14.219ms 99.99% 21.620ms 21.620ms 0.000us 0.00% 817.000us 817.000us 1
triton__0d1c 0.00% 0.000us 0.00% 0.000us 0.000us 600.000us 3.45% 600.000us 1.000us 600
triton_poi_fused_lift_fresh_0 5.61% 1.212ms 9.01% 1.949ms 9.745us 200.000us 1.15% 227.000us 1.135us 200
triton_poi_fused_arange_1 2.88% 622.000us 4.42% 955.000us 9.550us 100.000us 0.57% 114.000us 1.140us 100
triton_poi_fused_lift_fresh_2 2.61% 564.000us 4.00% 865.000us 8.650us 100.000us 0.57% 114.000us 1.140us 100
triton_poi_fused_lift_fresh_3 2.46% 531.000us 3.85% 832.000us 8.320us 100.000us 0.57% 114.000us 1.140us 100
triton_poi_fused_lift_fresh_4 2.42% 523.000us 3.81% 823.000us 8.230us 100.000us 0.57% 114.000us 1.140us 100
triton_poi_fused_lift_fresh_5 2.35% 509.000us 3.75% 810.000us 8.100us 100.000us 0.57% 114.000us 1.140us 100
triton__0d1 0.00% 0.000us 0.00% 0.000us 0.000us 100.000us 0.57% 100.000us 1.000us 100
------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 21.623ms
Self CUDA time total: 17.411ms
In the PyTorch Profiler output, the calls to those Triton functions are shown.
But the names triton__0d1c and triton__0d1 are shown as well (the only ones that nsys reported).
I am a bit confused. What are triton__0d1c and triton__0d1, and what is the corresponding code for them? Why are triton_poi_fused_lift_fresh_0, triton_poi_fused_arange_1, … triton_poi_fused_lift_fresh_5 not shown in nsys? What are they? Aren’t they kernels?
Not sure, but my hunch is the following:
The names triton_poi_fused_lift_fresh_0, triton_poi_fused_arange_1, … triton_poi_fused_lift_fresh_5 are just CPU-side wrappers that call the Triton kernels triton__0d1c and triton__0d1 to execute on the GPU. In other words, these 6 names map to the two Triton kernels.
For example, if I modify the call function as follows:
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1, ), (1, ), torch.bool)
        # Source Nodes: [tensor], Original ATen: [aten.lift_fresh]
        stream0 = get_raw_stream(0)
        triton_poi_fused_lift_fresh_0.run(buf0, 1, grid=grid(1), stream=stream0)
        buf1 = empty_strided_cuda((10, ), (1, ), torch.int64)
        # Source Nodes: [arange], Original ATen: [aten.arange]
        triton_poi_fused_arange_1.run(buf1, 10, grid=grid(10), stream=stream0)
        buf2 = empty_strided_cuda((), (), torch.int64)
        # Source Nodes: [tensor_1], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_2.run(buf2, 1, grid=grid(1), stream=stream0)
        buf3 = empty_strided_cuda((), (), torch.int64)
        # Source Nodes: [tensor_2], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_3.run(buf3, 1, grid=grid(1), stream=stream0)
        buf4 = empty_strided_cuda((), (), torch.int64)
        # Source Nodes: [tensor_3], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_4.run(buf4, 1, grid=grid(1), stream=stream0)
        buf5 = empty_strided_cuda((), (), torch.bool)
        # Source Nodes: [tensor_4], Original ATen: [aten.lift_fresh]
        triton_poi_fused_lift_fresh_5.run(buf5, 1, grid=grid(1), stream=stream0)
        # Removing the call for mapping test
        # buf6 = empty_strided_cuda((), (), torch.bool)
        # # Source Nodes: [tensor_5], Original ATen: [aten.lift_fresh]
        # triton_poi_fused_lift_fresh_0.run(buf6, 1, grid=grid(1), stream=stream0)
        buf6 = None
    return (buf0, buf2, buf3, buf4, buf5, buf6, buf1, )
The PyTorch Profiler reports:
$ python3 output_code_pytorch_profiler.py
STAGE:2024-04-13 14:38:21 24925:24925 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
0.000132
STAGE:2024-04-13 14:38:21 24925:24925 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-04-13 14:38:21 24925:24925 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
model_inference 0.00% 0.000us 0.00% 0.000us 0.000us 14.165ms 95.94% 14.165ms 14.165ms 1
model_inference 67.14% 12.104ms 99.91% 18.011ms 18.011ms 0.000us 0.00% 701.000us 701.000us 1
triton__0d1c 0.00% 0.000us 0.00% 0.000us 0.000us 500.000us 3.39% 500.000us 1.000us 500
triton_poi_fused_lift_fresh_0 3.74% 675.000us 5.89% 1.061ms 10.610us 100.000us 0.68% 114.000us 1.140us 100
triton_poi_fused_arange_1 3.37% 607.000us 5.06% 912.000us 9.120us 100.000us 0.68% 114.000us 1.140us 100
triton_poi_fused_lift_fresh_2 2.93% 528.000us 4.60% 830.000us 8.300us 100.000us 0.68% 114.000us 1.140us 100
triton_poi_fused_lift_fresh_3 2.80% 505.000us 4.47% 806.000us 8.060us 100.000us 0.68% 113.000us 1.130us 100
triton_poi_fused_lift_fresh_4 2.94% 530.000us 4.61% 831.000us 8.310us 100.000us 0.68% 113.000us 1.130us 100
triton_poi_fused_lift_fresh_5 2.82% 508.000us 4.50% 811.000us 8.110us 100.000us 0.68% 113.000us 1.130us 100
triton__0d1 0.00% 0.000us 0.00% 0.000us 0.000us 100.000us 0.68% 100.000us 1.000us 100
------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 18.028ms
Self CUDA time total: 14.765ms
The number of calls to triton_poi_fused_lift_fresh_0 drops from 200 to 100, and the count for triton__0d1c drops by the same amount, from 600 to 500.
Along the same lines, there seems to be the following mapping:
triton_poi_fused_lift_fresh_0 → triton__0d1c
triton_poi_fused_arange_1 → triton__0d1
triton_poi_fused_lift_fresh_2 → triton__0d1c
triton_poi_fused_lift_fresh_3 → triton__0d1c
triton_poi_fused_lift_fresh_4 → triton__0d1c
triton_poi_fused_lift_fresh_5 → triton__0d1c
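The call counts in the two profiler tables can be checked against this mapping with a few lines of arithmetic. A minimal sketch, where the numbers are copied from the "# of Calls" column of the two tables above and the mapping itself is just the hunch (arange wrapper to triton__0d1, everything else to triton__0d1c), not something confirmed by Inductor:

```python
from collections import Counter

# "# of Calls" for the CPU-side wrappers in the first profiler table.
BEFORE = {
    "triton_poi_fused_lift_fresh_0": 200,
    "triton_poi_fused_arange_1": 100,
    "triton_poi_fused_lift_fresh_2": 100,
    "triton_poi_fused_lift_fresh_3": 100,
    "triton_poi_fused_lift_fresh_4": 100,
    "triton_poi_fused_lift_fresh_5": 100,
}
# Second table: one launch of lift_fresh_0 was commented out in call().
AFTER = dict(BEFORE, triton_poi_fused_lift_fresh_0=100)


def predict_device_counts(wrapper_calls):
    """Sum wrapper call counts per device kernel under the assumed mapping."""
    totals = Counter()
    for wrapper, calls in wrapper_calls.items():
        # Assumed mapping: arange wrapper -> triton__0d1, all others -> triton__0d1c.
        kernel = "triton__0d1" if "arange" in wrapper else "triton__0d1c"
        totals[kernel] += calls
    return dict(totals)


print(predict_device_counts(BEFORE))  # {'triton__0d1c': 600, 'triton__0d1': 100}
print(predict_device_counts(AFTER))   # {'triton__0d1c': 500, 'triton__0d1': 100}
```

Both predictions match the triton__0d1c and triton__0d1 rows of the respective tables, which is consistent with the mapping but does not by itself prove it.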
You could also try recompiling with torch._inductor.config.profiler_mark_wrapper_call = True, which will turn on some annotations to improve the profiler output.
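On the naming question: every kernel body in output_code.py is defined as a function literally named triton_, and as far as I can tell Triton versions from around this time name the compiled device function by appending a per-argument specialization suffix to the Python function name. Under that assumption, the suffix can be reproduced from the divisible_by_16 and equal_to_1 tuples in triton_meta. The helper below is an illustrative reconstruction with hypothetical names, not Triton's actual function, and the scheme may differ in other Triton versions:

```python
def kernel_suffix(num_args, divisible_by_16=(), equal_to_1=()):
    """Illustrative reconstruction of the assumed suffix scheme.

    For each argument index in the signature: emit the index, then 'c' if the
    argument was specialized as equal to 1, then 'd' if divisible by 16.
    """
    parts = []
    for i in range(num_args):
        parts.append(str(i))
        if i in equal_to_1:
            parts.append("c")
        if i in divisible_by_16:
            parts.append("d")
    return "".join(parts)


def device_kernel_name(fn_name, num_args, **spec):
    """Assumed device-side name: Python function name + '_' + suffix."""
    return f"{fn_name}_{kernel_suffix(num_args, **spec)}"


# lift_fresh kernels: signature {0, 1}, divisible_by_16=(0,), equal_to_1=(1,)
print(device_kernel_name("triton_", 2, divisible_by_16=(0,), equal_to_1=(1,)))
# arange kernel: signature {0, 1}, divisible_by_16=(0,), equal_to_1=()
print(device_kernel_name("triton_", 2, divisible_by_16=(0,)))
```

With the metadata from the file above, this gives triton__0d1c for the five lift_fresh kernels and triton__0d1 for the arange kernel, which would explain both the double underscore and why nsys sees only two distinct names: kernels that share a function name and specialization get the same device-side name even though their bodies differ.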