I was going through the Inductor codebase and found the following:
The above code inside compile_fx_inner uses the cudagraphify function, which ultimately (using CUDAGraph Trees) tries to make a CUDAGraph out of compiled_graph.current_callable.
As per the signature of cudagraphify, the first argument is:
So, compiled_graph.current_callable is of type torch.fx.GraphModule.
Now my question is, what is this compiled_graph.current_callable: torch.fx.GraphModule?
In the compile_fx_inner function, I find compiled_graph defined as above.
Now gm is again a torch.fx.GraphModule.
So, my questions are:
- What is the torch.fx.GraphModule instance that compile_fx_inner takes as input? Is it an fx_graph module passed down from TorchDynamo?
- What is the torch.fx.GraphModule instance that is passed to cudagraphify in compile_fx_inner? Is it an optimized GraphModule with Triton kernels built into it?
- Next, is my following understanding correct:
We parse PyTorch code to produce several FX graphs (several because there may be graph breaks). Each of these FX graphs is a GraphModule, and TorchInductor generates Triton code for each one. So, in TorchInductor, we have a set of GraphModules (compiled to use Triton), and for each of these GraphModules we then decide whether to use CUDA Graphs.
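To make this picture concrete, here is a toy, pure-Python model of the flow I have in mind. It is only a sketch of my understanding, not Inductor's actual code: every name in it (GRAPH_BREAK, split_at_breaks, compile_segment, maybe_cudagraphify) is hypothetical, plain Python callables stand in for FX graphs and Triton kernels, and the per-graph use_cudagraphs flag is an assumption about where the decision is made.

```python
# Toy model only: none of these names exist in PyTorch or Inductor.
GRAPH_BREAK = object()  # hypothetical sentinel standing in for a Dynamo graph break


def split_at_breaks(ops):
    """Split a flat op list into segments, one per would-be FX graph."""
    segments, current = [], []
    for op in ops:
        if op is GRAPH_BREAK:
            if current:
                segments.append(current)
            current = []
        else:
            current.append(op)
    if current:
        segments.append(current)
    return segments


def compile_segment(segment):
    """Stand-in for Inductor: turn one segment into a single callable."""
    def compiled(x):
        for op in segment:
            x = op(x)
        return x
    return compiled


def maybe_cudagraphify(fn, use_cudagraphs):
    """Stand-in for the per-graph CUDA Graphs decision (assumed, not confirmed)."""
    if not use_cudagraphs:
        return fn

    def replay(x):
        # A real CUDA Graph would replay recorded kernels; here we just call fn.
        return fn(x)

    replay.is_cudagraphed = True
    return replay


# Three ops with one graph break in the middle give two "graphs".
ops = [lambda x: x + 1, lambda x: x * 2, GRAPH_BREAK, lambda x: x - 3]
segments = split_at_breaks(ops)

# Arbitrary choice for illustration: only the first graph gets wrapped.
callables = [
    maybe_cudagraphify(compile_segment(seg), use_cudagraphs=(i == 0))
    for i, seg in enumerate(segments)
]


def run(x):
    """Run the compiled segments back to back, as eager glue code would."""
    for fn in callables:
        x = fn(x)
    return x


result = run(5)  # (5 + 1) * 2 - 3
```

In this model the wrapping decision is made once per segment, after compilation, which is exactly the part I would like confirmed or corrected.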