I’m trying to reproduce the “Without inlining” ablation from the PyTorch 2.0 paper (the table that reports speedups for “All TorchInductor optimizations”, then “Without inlining”, “Without fusion”, “Without fusion & inlining”, etc.). My research goal is to turn off inlining (as well as other optimizations) in TorchInductor while keeping the rest of the pipeline intact, so I can attribute speedups precisely.
You can verify it is working by compiling something like torch.sin(torch.cos(x)): with inlining you should see one SchedulerNode; without it you should see two.
In the code, “realized” buffers are not inlined, while unrealized ones are.
Thanks for the pointers! I tried setting the inlining knobs to zero, but I still see sin(cos(x)) ending up in a single SchedulerNode (and one kernel mapping to ["cos","sin"]) in the debug dumps.
The script I used:
import torch, torch._inductor.config as cfg
cfg.realize_reads_threshold = 0
cfg.realize_opcount_threshold = 0
cfg.trace.enabled = True
def f(x):
    return torch.sin(torch.cos(x))
x = torch.randn(1_000_000, device="cpu")
g = torch.compile(f, fullgraph=True)(x)
I ran this with the following environment variables:
TORCH_COMPILE_DEBUG=1 TORCHINDUCTOR_FX_GRAPH_CACHE=0 python verify_inline.py
When I diff the two latest debug runs, the only difference I see is that one run’s generated fx_graph_runnable.py includes the threshold assignments and the other does not:
diff -r torch_compile_debug/run_2025_08_22_11_27_35_941034-pid_8665/ \
Here is a snippet of inductor_provenance_tracking_node_mappings.json:
{"preToPost": {"cos": ["cos"], "sin": ["sin"]}, "postToPre": {"cos": ["cos"], "sin": ["sin"]}, "cppCodeToPost": {"cpp_fused_cos_sin_0": ["sin", "cos"]}, "postToCppCode": {"sin": ["cpp_fused_cos_sin_0"], "cos": ["cpp_fused_cos_sin_0"]}}
Questions:
Are these thresholds read early enough that setting them inside the generated runnable wouldn’t affect lowering-time inlining? (i.e., do they need to be set even earlier than I am?)
Is there a specific file/decision point in lowering where “unrealized → inline” vs “realized → materialize” happens that I can instrument to confirm we’re hitting the intended branch?
Here is a workaround (the realize_reads_threshold path only applies when num_users > 1, so single-user buffers are inlined regardless of the thresholds):
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 31be050ab28..8d36cbbd4fe 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -1712,7 +1712,9 @@ class GraphLowering(torch.fx.Interpreter):
         # Realize if (1) any user need inputs realized, or (2) there is
         # already too many reads and rematerializing can be bad.
         num_users = len(OrderedSet(n.users))
-        if num_users > 1 and isinstance(result, TensorBox):
+        if num_users > 0 and isinstance(result, TensorBox):
+            result.realize()
+
             for user in n.users:
                 if user.target in needs_realized_inputs:
                     result.realize_hint()