Recap
We are working on an experimental project called TorchDynamo. TorchDynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. TorchDynamo hooks into the frame evaluation API in CPython to dynamically modify Python bytecode right before it is executed. It rewrites Python bytecode in order to extract sequences of PyTorch operations into an FX Graph which is then just-in-time compiled with an ensemble of different backends and autotuning. It creates this FX Graph through bytecode analysis and is designed to generate smaller graph fragments that can be mixed with Python execution to get the best of both worlds: usability and performance.
If you are new here, the TorchDynamo README is a good place to start. You can also catch up on our prior posts:
- Update 1: An Experiment in Dynamic Python Bytecode Transformation
- Update 2: 1.48x Geomean Speedup on TorchBench CPU Inference
- Update 3: GPU Inference Edition
- Update 4: Lazy Tensors & nvFuser Experiments
The TorchDynamo repository has moved to GitHub (facebookresearch/torchdynamo) and now has CI workflows.
TorchDynamo Capture Improvements
The biggest change since last time has been work to increase the amount of Python supported, allowing more captured ops and bigger graphs. The fraction of TorchBench operators captured by TorchDynamo increased from 83% to 99.8%, and 42 out of 52 models in TorchBench (80%) now produce a single whole-program graph.
For cases where users want whole-program graphs, we have added a nopython=True option to TorchDynamo, similar to the one available in numba. It turns graph breaks into errors and throws an exception referencing the line of code causing the graph break. See the usage example in README.md for more details.
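A minimal sketch of what this looks like in user code, following the my_compiler example in README.md; treat the exact signature as an assumption and defer to the README:

```python
import torch
import torchdynamo

def my_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Placeholder backend: just run the captured FX graph eagerly.
    return gm.forward

model = torch.nn.Linear(8, 8)
inputs = torch.randn(4, 8)

# nopython=True raises an error (instead of silently falling back to Python)
# if anything in this region would cause a graph break.
with torchdynamo.optimize(my_compiler, nopython=True):
    output = model(inputs)
```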
For 2 of the 10 remaining programs with graph breaks, the breaks can likely be removed with some additional work.
For the other 8, the reasons are more fundamental:
- Conversion to Python types (Tensor.item(), Tensor.tolist(), etc). Usage of these functions is common in PyTorch programs. Some usage is incidental (logging loss, early stopping, etc), but it is also common for these operations to be used for things like indexing, masking, and bounding box calculations.
- Usage of non-PyTorch libraries (most commonly numpy). Examples: 1) usage of numpy for indexing; 2) usage of numpy for randomness; 3) copy.deepcopy for finetuning.
These graph breaks are required to respect the semantics of the user program (reordering PyTorch ops with respect to arbitrary Python code is unsound). They are good examples of why PyTorch programs cannot always be represented as a single static whole-program graph. AOT Autograd is especially useful in these cases, as it allows compilers to move operations between the forwards and backwards passes without whole-program graphs.
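To make the first category concrete, here is an illustrative (made-up) training-step fragment where Tensor.item() and numpy both pull values out into plain Python, forcing a graph break at those points:

```python
import numpy as np
import torch

def step(model, x, y, loss_fn, threshold=0.01):
    out = model(x)
    loss = loss_fn(out, y)
    # Tensor.item() converts the loss to a Python float for early stopping,
    # which requires leaving the graph at this point.
    if loss.item() < threshold:
        return None
    # Indexing with a numpy array also forces a break, since numpy ops
    # cannot be captured in the FX graph.
    keep = np.nonzero(y.numpy() > 0)[0]
    return out[keep]
```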
Dynamic Shapes
For 2 of the above benchmarks (detectron2_maskrcnn/vision_maskrcnn) there are dynamic shapes that lead to lower capture and 200+ graphs due to over-specialization. We must prioritize supporting dynamic shapes in the near future, as this is a clear failure case that could cause issues for users. TorchDynamo currently has a torchdynamo.config.dynamic_shapes=True flag that disables shape specialization and puts calls to Tensor.size() into the FX graph as operators; however, backends do not currently support these types of FX graphs well. Backend support for dynamic shapes is mainly blocked on the dispatcher, which Nick is working on fixing for both AOT Autograd and Lazy Tensors.
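A minimal sketch of enabling the flag, reusing the placeholder my_compiler and model from the earlier sketch (and keeping in mind that backend support for the resulting graphs is still limited):

```python
import torch
import torchdynamo

# Disable shape specialization: Tensor.size() calls become operators in the
# FX graph instead of being baked into the graph as constants.
torchdynamo.config.dynamic_shapes = True

with torchdynamo.optimize(my_compiler):
    # With static shapes these two batch sizes would trigger two separate
    # specializations; with dynamic shapes one graph can serve both.
    model(torch.randn(4, 8))
    model(torch.randn(16, 8))
```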
Unfortunately, the problem of dynamic shapes is more complex than one might think. Enabling torchdynamo.config.dynamic_shapes will cause new graph breaks. Many models have code like assert x.shape == (1, 2, 3), if x.size(1) == 10:, or math.sqrt(x.shape[-1]). This kind of Python code operating on integer shapes is the de facto way to express many things in PyTorch. With static shapes, TorchDynamo can constant-propagate it away; with dynamic shapes, it breaks the graph. Zach has some interesting ideas around first-class dimensions that would make it easier for users to write shape-agnostic code.
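An illustrative (made-up) example of the pattern: the attention-style scaling below is fine with static shapes because math.sqrt(q.shape[-1]) constant-folds into the graph, but with dynamic shapes the size becomes symbolic and this line would force a graph break:

```python
import math
import torch

def scaled_scores(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # q.shape[-1] is a plain Python int under static shapes, so the sqrt is
    # constant-folded away; under dynamic shapes this value leaves the graph.
    scale = 1.0 / math.sqrt(q.shape[-1])
    return (q @ k.transpose(-2, -1)) * scale
```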
My current thinking is a “partially specialized shapes” mode in TorchDynamo. The basic idea is that all shapes start fully dynamic, but TorchDynamo converts a tensor’s shape to static when the user calls Tensor.size() and passes the result to a non-PyTorch operation. This would allow dynamic shapes most of the time, while still allowing bigger graphs when users operate directly on shapes as integers.
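Two made-up functions illustrate the proposed rule (this is a design sketch, not implemented behavior): the first would stay fully dynamic because the size only feeds PyTorch ops, while the second would trigger specialization because the size escapes into plain Python:

```python
import torch

def stays_dynamic(x: torch.Tensor) -> torch.Tensor:
    # The size() result only flows into a PyTorch op, so under the proposed
    # mode x's shape can remain symbolic.
    return x.reshape(x.size(0), -1)

def gets_specialized(x: torch.Tensor) -> torch.Tensor:
    # The size() result is used as a plain Python int (list indexing), so
    # under the proposed mode x's shape would be specialized to the value
    # seen here, with a guard added for reuse.
    scales = [0.5, 1.0, 2.0, 4.0]
    return x * scales[x.size(1) % 4]
```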
I’d also call out that we need better benchmarks that exhibit dynamic shape behavior. Currently, TorchBench is too static. Please submit more dynamic benchmarks to TorchBench!
Data on Capture and Graph Counts
The attached table shows per-benchmark data on operators captured and number of graphs.
- graph_calls is the count of graphs run by TorchDynamo
- captured_ops shows how many PyTorch operators were run inside of TorchDynamo graphs
- total_ops shows how many PyTorch operators were run in the benchmark, including those not captured by TorchDynamo and run eagerly
- pct_ops is captured_ops / total_ops
- pct_time is the fraction of execution time spent in TorchDynamo graphs; it is similar to pct_ops, but weighted by how long each operator takes to run and accounting for capture overheads.
Previews/Coming Soon
Animesh is exploring a TorchDynamo + AOT Autograd integration to provide training support in TorchDynamo, using AOT Autograd to capture and partition the backwards. Currently, 34 TorchBenchmark models pass an accuracy test that checks the gradients computed in the backwards pass. This effort has uncovered and fixed bugs in nvFuser, FX, AOT Autograd, TorchBench, and TorchDynamo.
Shunting is exploring a TorchDynamo + Lazy Tensors integration. He has a working prototype able to use Lazy Tensors backends in TorchDynamo. Currently, 36 TorchBenchmark models are working correctly for inference.
Stay tuned for full updates once both of these projects are working on all benchmarks.