TorchDynamo Update 5: Improved Capture & Bigger Graphs

Recap

We are working on an experimental project called TorchDynamo. TorchDynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. TorchDynamo hooks into the frame evaluation API in CPython to dynamically modify Python bytecode right before it is executed. It rewrites Python bytecode in order to extract sequences of PyTorch operations into an FX Graph which is then just-in-time compiled with an ensemble of different backends and autotuning. It creates this FX Graph through bytecode analysis and is designed to generate smaller graph fragments that can be mixed with Python execution to get the best of both worlds: usability and performance.

If you are new here, the TorchDynamo README is a good place to start; you can also catch up on our prior posts.

The TorchDynamo repository has moved to GitHub (facebookresearch/torchdynamo) and now has CI workflows.

TorchDynamo Capture Improvements

The biggest change since last time has been work to support more of Python, allowing more captured ops and bigger graphs. The fraction of TorchBench operators captured by TorchDynamo increased from 83% to 99.8%. Now, 42 out of 52 models in TorchBench (80%) provide a single whole-program graph.

For cases where users want whole-program graphs, we have added a nopython=True option to TorchDynamo, similar to the one available in numba. This will turn graph breaks into errors and throw an exception referencing the line of code causing the graph break. See the usage example in README.md for more details.
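
As a minimal sketch of how this is used (assuming the optimize() signature and the trivial my_compiler backend shown in README.md):

import torch
import torchdynamo

def my_compiler(gm, example_inputs):
    # trivial backend from the README: just run the captured FX graph as-is
    return gm.forward

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())

with torchdynamo.optimize(my_compiler, nopython=True):
    # any graph break inside this block raises an exception pointing at the
    # offending line, instead of silently splitting into multiple graphs
    model(torch.rand(4, 16))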

For 2 out of 10 of the remaining programs with graph breaks, the graph breaks can likely be removed with some additional work.

For 8 out of 10 of the remaining programs with graph breaks, the reasons are more fundamental:

  • Conversion to Python types (Tensor.item(), Tensor.tolist(), etc). Usage of these functions is common in PyTorch programs. Some usage is incidental (logging loss, early stopping, etc), but it is also common for these operations to be used for things like indexing, masking, and bounding box calculations.
  • Usage of non-PyTorch libraries (most commonly numpy). Examples: 1) usage of numpy for indexing; 2) usage of numpy for randomness; 3) copy.deepcopy for fine-tuning.

These graph breaks are required to respect the semantics of the user program (reordering PyTorch ops with respect to arbitrary code is unsound). These are good examples of why PyTorch programs can’t be represented as a static whole-program graph in all cases. AOT Autograd is especially useful in these cases as it allows compilers to move operations between forwards/backwards without whole-program graphs.
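
For illustration, here is a hypothetical snippet (not taken from any benchmark) combining the patterns above; each commented line forces a graph break:

import numpy as np
import torch

def step(model, x, target):
    out = model(x)
    loss = (out - target).pow(2).mean()
    print("loss:", loss.item())                # Tensor.item(): conversion to a Python float
    keep = np.random.rand(out.size(0)) > 0.5   # randomness computed in numpy
    return out[torch.from_numpy(keep)]         # numpy result used for indexing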

Dynamic Shapes

For 2 of the above benchmarks (detectron2_maskrcnn/vision_maskrcnn) there are dynamic shapes that lead to lower capture and 200+ graphs due to over-specialization. We must prioritize supporting dynamic shapes in the near future, as this is a clear failure case that could cause issues for users. TorchDynamo currently has a torchdynamo.config.dynamic_shapes=True flag that disables shape specialization and records calls to Tensor.size() as operators in the FX graph; however, backends do not currently support these types of FX graphs well. Backend support of dynamic shapes is mainly blocked on the dispatcher, which Nick is working on fixing for both AOT Autograd and Lazy Tensors.
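
A minimal sketch of enabling the flag (my_compiler is again the trivial backend; exactly what lands in the resulting FX graph is up to the capture):

import torch
import torchdynamo
import torchdynamo.config

# disable shape specialization; Tensor.size() calls are recorded as operators
# in the FX graph instead of being baked in as constants
torchdynamo.config.dynamic_shapes = True

def my_compiler(gm, example_inputs):
    return gm.forward  # trivial backend, as in the earlier sketch

def flatten(x):
    return x.view(x.size(0), -1)

with torchdynamo.optimize(my_compiler):
    flatten(torch.rand(8, 4, 4))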

Unfortunately, the problem of dynamic shapes is more complex than one might think. Enabling torchdynamo.config.dynamic_shapes will cause new graph breaks. Many models have code like assert x.shape == (1,2,3), if x.size(1) == 10, math.sqrt(x.shape[-1]), etc. This Python code operating on integer shapes is the de facto way to express many things in PyTorch. With static shapes, TorchDynamo can constant-propagate this away; with dynamic shapes, however, it will break the graph. Zach has some interesting ideas around first-class dimensions that would make it easier for users to write shape-agnostic code.
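
For example, a hypothetical attention-style scaling (made up for illustration) constant-propagates cleanly under static shapes but causes a graph break once shapes are dynamic:

import math
import torch

def scaled_scores(q, k):
    d = q.shape[-1]  # a plain Python int under static shapes, a symbolic size under dynamic shapes
    # math.sqrt on a symbolic size cannot stay in the graph, so this line breaks it
    return (q @ k.transpose(-2, -1)) / math.sqrt(d)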

My current thinking is a “partially specialized shapes” mode in TorchDynamo. The basic idea would be that all shapes start as fully dynamic, but then TorchDynamo would convert a tensor’s shapes to be static when the user called Tensor.size() and passed the result to a non-PyTorch operation. This would allow dynamic shapes most of the time, but still allow bigger graphs when users operate directly on shapes as integers.
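
A hypothetical example of how that mode might behave (a design sketch, not implemented behavior):

import torch

def pad_to_multiple(x, multiple=8):
    # x starts fully dynamic; using x.size(1) in plain Python arithmetic would,
    # under the proposed mode, specialize x's second dimension to a constant
    # rather than breaking the graph
    pad = (-x.size(1)) % multiple
    return torch.nn.functional.pad(x, (0, pad))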

I’d also call out that we need better benchmarks that exhibit dynamic shape behavior. Currently, TorchBench is too static. Please submit more dynamic benchmarks to TorchBench!

Data on Capture and Graph Counts

The attached table shows per-benchmark data on operators captured and number of graphs.

  • graph_calls is the count of graphs run by TorchDynamo
  • captured_ops shows how many PyTorch operators were run inside of TorchDynamo graphs
  • total_ops shows how many PyTorch operators were run in the benchmark, including those not captured by TorchDynamo and run eagerly
  • pct_ops is captured_ops / total_ops
  • pct_time is the amount of execution time spent in TorchDynamo graphs; it is similar to pct_ops, but weighted by how long each operator takes to run and taking capture overheads into account.
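
As a hypothetical illustration of how the two coverage metrics relate (the numbers and variable names here are made up, not the benchmark harness’s actual code):

# made-up per-benchmark numbers, not from the table
captured_ops, total_ops = 998, 1000
time_in_graphs, total_time = 0.95, 1.00   # seconds, with total_time including capture overhead

pct_ops = captured_ops / total_ops        # fraction of operators run inside TorchDynamo graphs
pct_time = time_in_graphs / total_time    # same idea, weighted by operator runtime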

Previews/Coming Soon

Animesh is exploring a TorchDynamo + AOT Autograd integration to provide training support in TorchDynamo by using AOT Autograd to capture/partition the backwards. Currently, 34 TorchBenchmark models are passing accuracy testing that checks computed gradients from the backwards pass. This effort has uncovered and fixed bugs in nvFuser, FX, AOT Autograd, TorchBench, and TorchDynamo.

Shunting is exploring a TorchDynamo + Lazy Tensors integration. He has a working prototype able to use Lazy Tensors backends in TorchDynamo. Currently, 36 TorchBenchmark models are working correctly for inference.

Stay tuned for full updates once both of these projects are working on all benchmarks.


Thank you for sharing.

Could you please compare cases where fx.symbolic_trace would fail (e.g., when it detects control flow based on an input) vs where TorchDynamo will break a graph?

None are due to control flow?

I’ve seen increasing usage of algorithms that heavily rely on control flow.
E.g. transformers.generation_beam_search.

The two maskrcnn benchmarks have dynamic control flow. I’d love to see more cases of dynamic control flow in TorchBench, I think it would be a good test for TorchDynamo and other frameworks. Contributions welcome here!

There are lots of differences between FX and TorchDynamo. Yes, control flow is one as you mentioned. I’d expect the coverage from TorchDynamo to be higher. Also, TorchDynamo is sound, while FX tracing is unsound.

One example of something TorchDynamo works on, but FX tracing doesn’t, would be:

if x.dtype == torch.float32:
   ...

(TorchDynamo will add a guard on x.dtype, and recompile if the dtype changes).

You could also come up with more wacky examples like:

import torch
import torchdynamo

global_list = []

def foo(a, b):
    x = a + b

    def bar(d):
        return x - d + 2

    global_list.append(bar)
    x += 1

# my_compiler is any TorchDynamo backend (e.g. the trivial one from the README)
with torchdynamo.optimize(my_compiler):
    foo(torch.rand(10), torch.rand(10))
    print(global_list[0](torch.rand(10)))

For this, TorchDynamo generates two graphs (one computing a + b + 1 and another computing x - d + 2).

@jansel could you please clarify the types of code that will lead to graph breaks for TorchDynamo?
I think I understand the conditions where fx.symbolic_trace will fail, but I don’t have any mental model of when TorchDynamo will break the graph versus just putting in a guard.

If this is already documented somewhere feel free to just point me to it, and I apologize for missing it.
Thanks!

The most fundamental ones are the two listed above: conversion to Python types (Tensor.item(), Tensor.tolist(), etc) and usage of non-PyTorch libraries such as numpy.

There is also a long tail of Python features that aren’t supported today, but could or will be supported in the future. For more on those, I’d encourage you to look at the GitHub issues. Specifically, the Python Coverage milestone has a bunch (but not all) of them.