TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation

In Next Steps for PyTorch Compilers, we laid out a vision of deploying eager mode PyTorch to more production settings and investing in using compilers to make eager mode faster and easier to maintain. This move away from graph mode makes some things a lot harder. For example, simple fusions that cross operator boundaries are at first glance not possible without users modifying their models. Lazy Tensors is one way to recapture these optimization opportunities. However, because it exists below the dispatcher, it cannot remove the overheads from Python and the upper levels of the PyTorch stack, so it may not be a good choice for smaller, overhead-bound models.

TorchDynamo is an early experiment that radically rethinks the approach for recapturing these optimization opportunities. It hooks into the frame evaluation API in CPython (PEP 523) to dynamically modify Python bytecode right before it is executed. This is analogous to what DynamoRIO does by dynamically modifying x86 machine code. TorchDynamo dynamically rewrites Python bytecode in order to extract sequences of PyTorch operations into an FX Graph, which is then just-in-time compiled with a user-defined compiler. It creates this FX Graph through bytecode analysis, not tracing, and is designed to generate smaller graph fragments that can be mixed with Python execution. This approach has many advantages:

  • It supports all Python code because it can easily fall back to running the original bytecode. It depends on a fast eager mode to work well, because the goal is to enhance eager mode rather than replace it.
  • It has extremely low overhead, and it can even remove Python overheads from the original program, because it intercepts execution at the very top of the stack.
  • It does not introduce any added latency by deferring execution.

How TorchDynamo works

The figure above shows how TorchDynamo changes the behavior of CPython. TorchDynamo installs a custom eval frame function that performs dynamic bytecode analysis and transformation. The transformations insert calls to compiled FX Graphs into the bytecode, and reuse of these compiled artifacts is protected by guards to ensure soundness.
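A custom eval frame function is installed through CPython's C API, so it cannot be reproduced in a few lines of Python. As a loose analogy only, the sketch below illustrates the underlying idea that bytecode is data that can be rewritten before it runs: it swaps a constant inside a function's code object using the standard library's `CodeType.replace` (Python 3.8+). The helper `rewrite_constant` is a hypothetical name, and unlike TorchDynamo this edit is permanent rather than applied per frame behind guards.

```python
import dis
import types


def half(x):
    return x / 2.0


def rewrite_constant(func: types.FunctionType, old: float, new: float) -> None:
    """Swap a float constant in func's code object, in place.

    Toy analogy only: this edits a function's bytecode constants before its
    next call, whereas TorchDynamo intercepts each frame via a C-level hook.
    """
    code = func.__code__
    consts = tuple(
        new if isinstance(c, float) and c == old else c for c in code.co_consts
    )
    # CodeType.replace returns a copy of the code object with selected fields changed.
    func.__code__ = code.replace(co_consts=consts)


if __name__ == "__main__":
    print(half(8.0))  # prints 4.0 (original bytecode divides by 2.0)
    dis.dis(half)     # inspect the instructions CPython will run
    rewrite_constant(half, 2.0, 4.0)
    print(half(8.0))  # prints 2.0 (same instructions, new constant)
```

The instructions stay the same and only the constant they load changes, which keeps the example valid across CPython versions without hand-assembling opcodes.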

To make this process clearer, let’s go through an example. Consider this toy code:

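import torch
import torchdynamo
def custom_compiler(graph: torch.fx.GraphModule):
    # minimal compiler: run each extracted graph unchanged
    return graph.forward
def fn(a, b):
    x = a + b
    x = x / 2.0
    if x.sum() < 0:
        return x * -1.0
    return x
with torchdynamo.optimize(custom_compiler):
    fn(torch.randn(10), torch.randn(10))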

This toy example results in the following original Python bytecode for fn():

 0  LOAD_FAST 0 (a)
 2  LOAD_FAST 1 (b)
 4  BINARY_ADD
 6  STORE_FAST 2 (x)

 8  LOAD_FAST 2 (x)
 10 LOAD_CONST 1 (2.0)
 12 BINARY_TRUE_DIVIDE
 14 STORE_FAST 2 (x)

 16 LOAD_FAST 2 (x)
 18 LOAD_METHOD 0 (sum)
 20 CALL_METHOD 0
 22 LOAD_CONST 2 (0)
 24 COMPARE_OP 0 (<)
 26 POP_JUMP_IF_FALSE 36

 28 LOAD_FAST 2 (x)
 30 LOAD_CONST 3 (-1.0)
 32 BINARY_MULTIPLY
 34 RETURN_VALUE

 36 LOAD_FAST 2 (x)
 38 RETURN_VALUE
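The listing above is CPython 3.8 output. The sketch below shows one way to produce a similar listing yourself with the standard `dis` module. `fn` is only disassembled, never called, so torch does not need to be installed. The opcodes you see depend on your Python version (for example, 3.11+ replaces BINARY_ADD and its siblings with BINARY_OP). The helper `opnames` is a hypothetical convenience function.

```python
import dis


def fn(a, b):
    x = a + b
    x = x / 2.0
    if x.sum() < 0:
        return x * -1.0
    return x


def opnames(func) -> list:
    """Return the opcode names of func's bytecode, in order."""
    return [instr.opname for instr in dis.get_instructions(func)]


if __name__ == "__main__":
    # Prints offsets, opcode names and arguments, like the listing above.
    dis.dis(fn)
```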

TorchDynamo dynamically rewrites that bytecode as follows:

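 0  LOAD_GLOBAL 1 (__compiled_fn_0)
 2  LOAD_FAST 0 (a)
 4  LOAD_FAST 1 (b)
 6  CALL_FUNCTION 2
 8  UNPACK_SEQUENCE 2
 10 STORE_FAST 2 (x)
 12 POP_JUMP_IF_FALSE 22
 14 LOAD_GLOBAL 2 (__compiled_fn_1)
 16 LOAD_FAST 2 (x)
 18 CALL_FUNCTION 1
 20 RETURN_VALUE

 22 LOAD_FAST 2 (x)
 24 RETURN_VALUE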

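This new bytecode calls the two compiled FX graphs shown below. The data-dependent control flow (if x.sum() < 0) splits the program into two graphs: the first computes x and the branch condition, and the second computes x * -1.0.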

opcode         name     target                       args              kwargs
-------------  -------  ---------------------------  ----------------  --------
placeholder    a_0      a_0                          ()                {}
placeholder    b_1      b_1                          ()                {}
call_function  add      <built-in function add>      (a_0, b_1)        {}
call_function  truediv  <built-in function truediv>  (add, 2.0)        {}
call_method    sum_1    sum                          (truediv,)        {}
call_function  lt       <built-in function lt>       (sum_1, 0)        {}
output         output   output                       ((truediv, lt),)  {}

opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    x_4     x_4                      ()           {}
call_function  mul     <built-in function mul>  (x_4, -1.0)  {}
output         output  output                   (mul,)       {}
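To make the split concrete, here is a plain-Python sketch of what the rewritten bytecode does. Lists of floats stand in for tensors, and `compiled_fn_0` and `compiled_fn_1` are hand-written, hypothetical stand-ins (named after `__compiled_fn_0` and `__compiled_fn_1`) for the callables a user compiler would return for the two graphs above. TorchDynamo itself emits bytecode, not Python source like this.

```python
from typing import List, Tuple

Vec = List[float]  # stand-in for a 1-D tensor


def original_fn(a: Vec, b: Vec) -> Vec:
    # The toy function from above, written over lists of floats.
    x = [ai + bi for ai, bi in zip(a, b)]
    x = [xi / 2.0 for xi in x]
    if sum(x) < 0:
        return [xi * -1.0 for xi in x]
    return x


def compiled_fn_0(a: Vec, b: Vec) -> Tuple[Vec, bool]:
    # Stand-in for graph 1: add, truediv, sum, lt; returns (truediv, lt).
    truediv = [(ai + bi) / 2.0 for ai, bi in zip(a, b)]
    return truediv, sum(truediv) < 0


def compiled_fn_1(x: Vec) -> Vec:
    # Stand-in for graph 2: a single mul by -1.0.
    return [xi * -1.0 for xi in x]


def rewritten_fn(a: Vec, b: Vec) -> Vec:
    # Mirrors the rewritten bytecode: call graph 1, unpack, branch in Python.
    x, lt = compiled_fn_0(a, b)
    if lt:
        return compiled_fn_1(x)
    return x
```

The branch on `lt` stays in ordinary Python, which is why the computation before and after it ends up in two separate graphs.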

Finally, TorchDynamo generates two guards:

  • Local arg “a” must be a torch.Tensor
  • Local arg “b” must be a torch.Tensor

Failure of either of these guards triggers re-analysis and transformation.
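The guard-and-recompile cycle can be modeled in a few lines. This is a hypothetical toy: `GuardedCache` and `toy_compile` are illustrative names, and the only guards modeled are type checks on frame locals like the two listed above, with Python's `float` and `int` standing in for `torch.Tensor`.

```python
from typing import Any, Callable, Dict, List, Tuple

# A guard is a (local name, required type) pair, a toy version of checks such
# as "local arg a must be a torch.Tensor".
Guard = Tuple[str, type]
CompileFn = Callable[[Dict[str, Any]], Tuple[List[Guard], Callable]]


class GuardedCache:
    """Reuse a compiled artifact only while all of its guards still hold."""

    def __init__(self, compile_fn: CompileFn) -> None:
        self.compile_fn = compile_fn
        self.entries: List[Tuple[List[Guard], Callable]] = []
        self.compiles = 0

    def lookup(self, frame_locals: Dict[str, Any]) -> Callable:
        for guards, compiled in self.entries:
            if all(isinstance(frame_locals[name], typ) for name, typ in guards):
                return compiled  # every guard passed: safe to reuse
        # No cached entry matches: re-analyze and compile a new artifact.
        self.compiles += 1
        guards, compiled = self.compile_fn(frame_locals)
        self.entries.append((guards, compiled))
        return compiled


def toy_compile(frame_locals: Dict[str, Any]) -> Tuple[List[Guard], Callable]:
    # Hypothetical compiler: specializes on the current types of a and b.
    guards = [(name, type(frame_locals[name])) for name in ("a", "b")]
    return guards, lambda a, b: a + b


if __name__ == "__main__":
    cache = GuardedCache(toy_compile)
    cache.lookup({"a": 1.0, "b": 2.0})  # first call: compile
    cache.lookup({"a": 3.0, "b": 4.0})  # guards pass: reuse
    cache.lookup({"a": 1, "b": 2})      # float guards fail: recompile for int
    print(cache.compiles)               # prints 2
```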

If TorchDynamo encounters calls to non-PyTorch code, or Python constructs it does not handle, it leaves them in the original bytecode. TorchDynamo thus opportunistically finds optimization opportunities without sacrificing the Python user experience.
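A simplified model of this opportunistic capture treats a function as a flat list of ops, each marked capturable or not. Consecutive capturable ops are grouped into graph fragments, and everything else runs as ordinary Python. This is an illustrative assumption: TorchDynamo works on bytecode rather than a list of named ops, and `partition` is a hypothetical helper.

```python
from typing import List, Tuple

# Each op is (name, capturable); capturable=False marks code that would be
# left in the original bytecode (e.g. a call into a non-PyTorch library).
Op = Tuple[str, bool]


def partition(ops: List[Op]) -> List[Tuple[str, List[str]]]:
    """Group consecutive capturable ops into 'graph' fragments.

    Non-capturable ops are emitted one by one as 'python' segments, so the
    output preserves the original execution order.
    """
    segments: List[Tuple[str, List[str]]] = []
    current: List[str] = []
    for name, capturable in ops:
        if capturable:
            current.append(name)
            continue
        if current:
            segments.append(("graph", current))
            current = []
        segments.append(("python", [name]))
    if current:
        segments.append(("graph", current))
    return segments


if __name__ == "__main__":
    ops = [("add", True), ("truediv", True), ("print", False), ("mul", True)]
    for kind, names in partition(ops):
        print(kind, names)
```

With the sample ops this yields a graph fragment for add and truediv, a Python segment for print, and a second graph fragment for mul.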


Here is how the current API works:

from typing import Callable
import torch
import torchdynamo
def custom_compiler(graph: torch.fx.GraphModule) -> Callable:
    # do cool compiler optimizations here
    return graph.forward
with torchdynamo.optimize(custom_compiler):
    # any PyTorch code
    # custom_compiler() is called to optimize extracted fragments
    # should reach a fixed point where nothing new is compiled
    pass
# Optionally:
#     any PyTorch code
#     previously compiled artifacts are reused
#     provides a quiescence guarantee, without compiles

You define a compiler function (which compiles an FX graph to a Python callable), then wrap the code you want TorchDynamo to optimize in a torchdynamo.optimize context manager. That should be all you need. For cases where you want to guarantee that no compile warmup is added, we also provide a way to reuse prior optimizations from torchdynamo.optimize without triggering any new compiles.
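The difference between the two modes can be sketched with a toy model. Everything here is hypothetical: `ToyDynamo`, `call`, and `reuse_only` are invented names (`reuse_only` is a made-up label for the reuse-only mode), the cache key is simply the function plus its argument types, and the "compiler" receives a Python function rather than an FX graph. What the sketch captures is the contract described above: inside `optimize`, cache misses trigger compiles; in the reuse-only mode, hits reuse prior artifacts and misses fall back to eager execution, so nothing new is compiled.

```python
import contextlib
from typing import Any, Callable, Dict, Iterator, Optional, Tuple


class ToyDynamo:
    """Hypothetical model of compile-on-demand vs reuse-only execution."""

    def __init__(self) -> None:
        self.mode = "off"  # "off", "optimize", or "reuse"
        self.compiler: Optional[Callable[[Callable], Callable]] = None
        self.cache: Dict[Tuple[Callable, Tuple[type, ...]], Callable] = {}
        self.compiles = 0

    @contextlib.contextmanager
    def _use_mode(self, mode: str, compiler: Optional[Callable]) -> Iterator[None]:
        saved = (self.mode, self.compiler)
        self.mode, self.compiler = mode, compiler
        try:
            yield
        finally:
            self.mode, self.compiler = saved

    def optimize(self, compiler: Callable[[Callable], Callable]):
        # Analogous to torchdynamo.optimize: misses are compiled and cached.
        return self._use_mode("optimize", compiler)

    def reuse_only(self):
        # Made-up name for the reuse-only mode: hits are reused, no compiles.
        return self._use_mode("reuse", None)

    def call(self, fn: Callable, *args: Any) -> Any:
        if self.mode == "off":
            return fn(*args)
        key = (fn, tuple(type(a) for a in args))
        compiled = self.cache.get(key)
        if compiled is None and self.mode == "optimize":
            assert self.compiler is not None
            self.compiles += 1
            compiled = self.compiler(fn)
            self.cache[key] = compiled
        # In reuse-only mode a miss falls back to the original eager function.
        target = compiled if compiled is not None else fn
        return target(*args)


if __name__ == "__main__":
    dynamo = ToyDynamo()

    def double(x):
        return 2 * x

    with dynamo.optimize(lambda f: f):  # identity "compiler"
        dynamo.call(double, 3)          # miss: compiles once
    with dynamo.reuse_only():
        dynamo.call(double, 4)          # hit: reuses the int artifact
        dynamo.call(double, 1.5)        # miss: runs eagerly, no compile
    print(dynamo.compiles)              # prints 1
```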

Early results

This project is still very early, so we haven’t tried applying optimizations yet and have been focusing on correctness, overhead, and coverage. We measured 35 TorchBench models using Python 3.8 on an Intel CPU. Raw results are here.

To summarize the results in the key focus areas:

  • Correctness: 100%
    • Correctness is by far the most important metric: it measures how many models run and produce the right answer. The goal is to make zero sacrifices to user experience and sit at the maximum-usability end of the Pareto-optimal curve. This is still early work, so there are surely bugs and gaps, though running on TorchBench gives some confidence that it works for a wide variety of models.
  • Overhead: <1% average
    • Checking guards and patching frame objects adds some overhead. On the measured models, overhead is under 1% for most models, and TorchDynamo actually speeds many models up slightly. Because the FX compiler function does not yet apply any optimizations, we pay all the costs and get none of the benefits, so this is a worst-case measurement. A later focus will be using TorchDynamo to apply optimizations and get speedups.
  • Coverage: 60% of ops, 64% of time
    • The final metric is how many ops TorchDynamo captures, out of the total ops in the whole model. This early version captures 60% of all ops (which account for 64% of execution time). Some models have 0% captured, others 100%, and most fall somewhere in between. There are still many missing features to add, so coverage is the current main focus for improvement.
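For readers who want to compute similar numbers on their own models, here is a sketch of how the overhead and coverage metrics could be calculated from per-op records. The `OpRecord` format and both formulas are assumptions about the definitions, not the scripts behind the results above, and the sample numbers are invented for illustration.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class OpRecord:
    name: str
    seconds: float  # time spent in this op during one run
    captured: bool  # True if the op ended up inside an extracted FX graph


def coverage(ops: List[OpRecord]) -> Tuple[float, float]:
    """Return (fraction of ops captured, fraction of op time captured)."""
    captured = [op for op in ops if op.captured]
    op_fraction = len(captured) / len(ops)
    time_fraction = sum(op.seconds for op in captured) / sum(op.seconds for op in ops)
    return op_fraction, time_fraction


def overhead(eager_seconds: float, dynamo_seconds: float) -> float:
    """Relative slowdown vs. eager; a negative value means a small speedup."""
    return (dynamo_seconds - eager_seconds) / eager_seconds


if __name__ == "__main__":
    # Invented numbers for illustration only.
    ops = [
        OpRecord("add", 1.0, True),
        OpRecord("mul", 3.0, True),
        OpRecord("custom_op", 4.0, False),
        OpRecord("sum", 2.0, False),
    ]
    print(coverage(ops))           # prints (0.5, 0.4)
    print(overhead(100.0, 101.0))  # prints 0.01
```

Reporting both fractions matters because, as in the results above, the share of ops captured and the share of time captured can differ.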

Next Steps

Check out the source code here. This is still an experimental prototype, so use it at your own risk. If you want to contribute, please reach out to us.

There is still a ton of work left to do, so stay tuned for future updates that (hopefully) include higher coverage and some applications resulting in speedups!