A TorchDynamo trace time ablation study

TorchDynamo compile times have been a focus lately. In this post I want to break down where that time goes. I will show some profile results and perform an ablation study that estimates how much compile time various types of optimizations could save. Together, these ablations point to a path towards a 10x speedup in TorchDynamo trace times.

The microbenchmark

Here is the microbenchmark this post will focus on today.

import torch

@torch.compile(backend="eager", fullgraph=True)
def tensor_dicts(inputs):
    result = torch.zeros_like(inputs[0])
    for k1 in inputs:
        for k2 in inputs:
            result = result + torch.sin(inputs[k1] + inputs[k2])
    return result

# 100 single-element tensors keyed by int
tensor_dicts({i: torch.randn(1) for i in range(100)})

This benchmark stresses dictionary iteration and access in TorchDynamo, and it creates ~20000 ops in the output graph.

This benchmark initially runs in 8.9s on my local desktop. Note that this is already down from 9.6s before this stack of PRs landed; those PRs sped up TorchDynamo bytecode processing.

Profiling with cProfile and visualizing the result in snakeviz yields the attached profile.

We can see:

  • 48% of time is spent in BINARY_OP handling (the add)
  • 33% of time is spent in CALL handling (torch.sin)
  • 7% of time is spent in RETURN_VALUE (final graph processing)
  • 3% of time is spent in LOAD_ATTR (reading values from the dict)
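To reproduce a profile like this one, the standard-library `cProfile` and `pstats` modules are enough. The sketch below is a minimal, hedged example: `workload` is a hypothetical stand-in for the call being profiled (in this post, the `tensor_dicts(...)` call), and the dumped `trace.prof` file is what you would open with `snakeviz trace.prof`.

```python
import cProfile
import pstats


def workload():
    # Hypothetical stand-in for the compiled call being profiled.
    total = 0
    for i in range(100):
        for j in range(100):
            total += i * j
    return total


profiler = cProfile.Profile()
profiler.enable()
result = workload()
profiler.disable()

# Raw stats file that snakeviz can open for the interactive view.
profiler.dump_stats("trace.prof")

# Quick text view: the 10 most expensive entries by cumulative time.
stats = pstats.Stats("trace.prof")
stats.sort_stats("cumulative").print_stats(10)
```

Sorting by cumulative time is what surfaces handlers such as BINARY_OP and CALL at the top, since their cost includes everything they call into.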

Fake tensors

Looking deeper into the profile above, both BINARY_OP and CALL spend a lot of time in get_fake_value, which does fake tensor shape propagation. We can ablate this compile time cost away with the following patch:

diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index cf4082630ff..1c07f606b64 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -1650,6 +1650,10 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
     if "example_value" in node.meta and is_fake(node.meta["example_value"]):
         return node.meta["example_value"]
 
+    # Hack to bypass output shape/dtype/device computation.
+    # Assume everything returns the same shape as its first input.
+    return node.args[0].meta["example_value"]
+
     args, kwargs = get_fake_values_from_nodes(
         tx, (node.args, node.kwargs), allow_non_graph_fake
     )

The patch improves compile times from 8.9s to 4.8s, so a 1.85x speedup would be possible if we could either make fake tensors faster or replace them with a faster way to compute the output shapes of ops. Perhaps we could cache this shape computation at the TensorVariable level and avoid calling into fake tensors most of the time.
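As a rough illustration of the caching idea, here is a sketch that memoizes output metadata keyed on the op name plus the metadata of its inputs. `TensorMeta`, `MetaCache`, and `slow_propagate` are hypothetical names for this example, not Dynamo APIs; the slow path stands in for fake tensor propagation. The loop mirrors the benchmark's add/sin/add pattern.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple


@dataclass(frozen=True)
class TensorMeta:
    # Hypothetical stand-in for the metadata a TensorVariable tracks.
    shape: Tuple[int, ...]
    dtype: str
    device: str


Inputs = Tuple[TensorMeta, ...]


class MetaCache:
    """Memoize output metadata per (op name, input metadata) key."""

    def __init__(self, propagate: Callable[[str, Inputs], TensorMeta]):
        self._propagate = propagate  # slow path, e.g. running the op on fake tensors
        self._cache: Dict[Tuple[str, Inputs], TensorMeta] = {}
        self.misses = 0

    def output_meta(self, op: str, inputs: Inputs) -> TensorMeta:
        key = (op, inputs)
        cached = self._cache.get(key)
        if cached is None:
            self.misses += 1
            cached = self._propagate(op, inputs)
            self._cache[key] = cached
        return cached


def slow_propagate(op: str, inputs: Inputs) -> TensorMeta:
    # Hypothetical slow path: only same-shape elementwise ops are handled here.
    if op not in ("add", "sin") or any(m != inputs[0] for m in inputs):
        raise NotImplementedError(op)
    return inputs[0]


cache = MetaCache(slow_propagate)
x = TensorMeta((1,), "float32", "cpu")
result = x
for _ in range(100 * 100):
    s = cache.output_meta("add", (x, x))
    s = cache.output_meta("sin", (s,))
    result = cache.output_meta("add", (result, s))
print(cache.misses)  # 2 slow-path calls for 30,000 lookups
```

A real version would presumably need a richer key (strides, requires_grad, symbolic sizes) and a bypass for ops whose output metadata depends on tensor values.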

FX Node creation

Next, I noticed a lot of time being spent in create_proxy and wrap_fx_proxy, which insert new nodes into the FX graph, track line numbers (for debugging), and specialize on the result. I ablated away this compile time cost with:

diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 21ce272c997..357c4d7010e 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -824,6 +824,21 @@ class BuiltinVariable(VariableTracker):
         return builtin_dipatch
 
     def _handle_insert_op_in_graph(self, tx, args, kwargs):
+        # Hack to skip creating a new FX node, just return a copy of args[0]
+        if isinstance(args[0], TensorVariable):
+            value = args[0]
+            return TensorVariable(
+                value.proxy,
+                dtype=value.dtype,
+                device=value.device,
+                layout=value.layout,
+                ndim=value.ndim,
+                requires_grad=value.requires_grad,
+                is_quantized=value.is_quantized,
+                is_sparse=value.is_sparse,
+                class_type=value,
+            )
+
         from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
 
         if kwargs and not self.tensor_args(*args, *kwargs.values()):
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index 5571193e369..fc3d8bcd259 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -673,6 +673,21 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
         from . import ConstantVariable, SymNodeVariable, TensorVariable
         from .builder import wrap_fx_proxy
 
+        # Hack to skip creating a new FX node, just return a copy of args[0]
+        if isinstance(args[0], TensorVariable):
+            value = args[0]
+            return TensorVariable(
+                value.proxy,
+                dtype=value.dtype,
+                device=value.device,
+                layout=value.layout,
+                ndim=value.ndim,
+                requires_grad=value.requires_grad,
+                is_quantized=value.is_quantized,
+                is_sparse=value.is_sparse,
+                class_type=value,
+            )
+
         if self.can_constant_fold_through() and check_unspec_or_constant_args(
             args, kwargs
         ):

This patch further improves things from 4.8s to 1.2s, for a hypothetical 7.4x speedup over the original 8.9s. There seems to be a lot of room for compile time speedups here.
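To get a feel for why per-node debug bookkeeping adds up, here is a toy graph builder that optionally records the Python stack for every node it creates, loosely analogous to the line-number tracking mentioned above. `Graph`, `Node`, and `build` are hypothetical classes for this sketch, not `torch.fx`; `build` creates a graph with the same shape as the benchmark, using 10 keys instead of 100.

```python
import timeit
import traceback
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass(eq=False)
class Node:
    op: str
    args: Tuple["Node", ...]
    # Debug info: the Python stack at creation time, or None if not recorded.
    stack: Optional[traceback.StackSummary] = None


@dataclass
class Graph:
    record_stack: bool = True
    nodes: List[Node] = field(default_factory=list)

    def create_node(self, op: str, *args: Node) -> Node:
        # extract_stack() walks every live frame, so this debug info costs time
        # on every node, proportional to the call depth.
        stack = traceback.extract_stack() if self.record_stack else None
        node = Node(op, args, stack)
        self.nodes.append(node)
        return node


def build(graph: Graph, n: int) -> Node:
    # Same shape as the benchmark: result = result + sin(a + b) for every pair.
    result = graph.create_node("placeholder")
    inputs = [graph.create_node("placeholder") for _ in range(n)]
    for a in inputs:
        for b in inputs:
            s = graph.create_node("sin", graph.create_node("add", a, b))
            result = graph.create_node("add", result, s)
    return result


plain = Graph(record_stack=False)
build(plain, 10)
print(len(plain.nodes))  # 311 = 1 + 10 + 3 * 10 * 10

for record in (False, True):
    secs = timeit.timeit(lambda: build(Graph(record_stack=record), 10), number=20)
    print(f"record_stack={record}: {secs:.3f}s")
```

The two timings only show the cost of stack capture in this toy; how closely that tracks FX's real overheads is an assumption, not a measurement.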

Dynamo overheads

Looking at the updated profile, the overheads of the actual dynamo tracing (the processing of Python bytecode) now start to matter.
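Dictionary access is one example of this per-bytecode overhead. The sketch below is a hypothetical model, loosely based on the `getitem_const` fragment in this section, of what a constant-key lookup costs when dict keys are wrapped in a tracker object: every access allocates a wrapper and makes Python-level `__hash__`/`__eq__` calls. `ConstantVar`, `HashableTracker`, and `ConstDict` are illustrative names, not Dynamo's classes, and the hash counter exists only for this sketch.

```python
from typing import Any, Dict


class ConstantVar:
    # Hypothetical stand-in for a tracker object wrapping a Python constant.
    def __init__(self, value: Any):
        self.value = value


class HashableTracker:
    # Hypothetical wrapper that lets tracker objects act as dict keys by
    # delegating hashing and equality to the wrapped constant.
    hash_calls = 0  # instrumentation for this sketch only

    def __init__(self, vt: ConstantVar):
        self.vt = vt

    def __hash__(self) -> int:
        HashableTracker.hash_calls += 1
        return hash(self.vt.value)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, HashableTracker) and self.vt.value == other.vt.value


class ConstDict:
    def __init__(self, items: Dict[HashableTracker, Any]):
        self.items = items

    def getitem_const(self, arg: ConstantVar) -> Any:
        key = HashableTracker(arg)  # one allocation per access
        if key not in self.items:  # Python-level __hash__ (and __eq__ on a match)
            raise KeyError(arg.value)
        return self.items[key]  # hashed again for the actual lookup


d = ConstDict({HashableTracker(ConstantVar(i)): f"tensor_{i}" for i in range(100)})
HashableTracker.hash_calls = 0
print(d.getitem_const(ConstantVar(42)))  # tensor_42
print(HashableTracker.hash_calls)  # 2
```

With 100 keys, the benchmark's nested loop performs this kind of lookup on every `inputs[k1]` and `inputs[k2]` access, so small per-access costs are multiplied many thousands of times.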

At this point, the speedups become smaller. For example, this patch:

    def getitem_const(self, arg: VariableTracker):
+        return self.items[0]
         key = ConstDictVariable._HashableTracker(arg)
         if key not in self.items:
             raise KeyError(arg.value)

Improves things from 1.2s to 1.0s, and this patch:

diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 973d79288d5..7a94912299b 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -722,8 +722,9 @@ class InstructionTranslatorBase(
         self.current_instruction = inst = self.instructions[ip]
         self.instruction_pointer = ip + 1
 
-        if inst.starts_line:
-            self.starts_line(inst.starts_line)
+        # skip tracking current line number for stack traces
+        # if inst.starts_line:
+        #     self.starts_line(inst.starts_line)
 
         if (
             not self.stack

Improves things further to 0.9s.
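The ablations above are cumulative, so each speedup is the 8.9s baseline divided by the time measured after that step. The short script below only recomputes those ratios from the timings reported in this post; the step labels are descriptive names, not identifiers from the code.

```python
# Trace times (seconds) reported in this post after each cumulative ablation.
timings = [
    ("baseline", 8.9),
    ("skip fake tensor propagation", 4.8),
    ("skip FX node creation", 1.2),
    ("shortcut getitem_const", 1.0),
    ("skip line-number tracking", 0.9),
]

baseline = timings[0][1]
speedups = {name: baseline / seconds for name, seconds in timings}

for name, seconds in timings:
    print(f"{name:30s} {seconds:4.1f}s  {speedups[name]:5.2f}x")
```

The last row, roughly 9.9x, is where the ~10x figure in the introduction comes from.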

Concluding thoughts

TorchDynamo compile time is highly optimizable! At the top of the list is finding a faster way to compute output metadata for PyTorch functions. The next biggest win would be faster FX graph construction and debug-data tracking. The actual dynamo processing of bytecodes doesn’t seem to be the main bottleneck in dynamo trace times.
