with Will Constable, Jason Ansel
with Jack Cao from Google PyTorch/XLA team
TLDR: We’ve built a prototype bridge to integrate dynamo with PyTorch/XLA. We benchmarked the bridge on a subset of 10 pytorch/benchmark models. For inference, we verified numerical correctness and achieved a 1.5x geomean speedup on GPU and a 1.8x geomean speedup on TPU compared to the PyTorch/XLA baseline. For training, we verified numerical correctness, summarized the overall perf results (slightly better than neutral overall) on GPU, and analyzed the cause of the slowdown on some models.
There are two major motivations to integrate dynamo with PyTorch/XLA. First, it enables users to run models on TPU through dynamo; second, we can leverage dynamo’s guard system to skip PyTorch/XLA’s tracing overhead.
Dynamo and PyTorch/XLA use different mechanisms to guarantee sound graph capture (see The nuances of PyTorch Graph Capture for a comprehensive comparison of graph capture techniques). Dynamo leverages its guard system to record the conditions that must hold for a captured graph to be reused; the guard checks are very cheap. PyTorch/XLA, on the other hand, traces the graph on every run and maintains a hash of the graph. A previously captured and optimized graph is reused if the hash hits the cache. The tracing overhead varies by model but is usually quite high.
If we naively integrated dynamo and PyTorch/XLA, we would have both dynamo’s guard system and PyTorch/XLA’s tracing system guaranteeing sound graph capture, which is redundant. Instead, we want PyTorch/XLA to trace the graph once at compile time and skip tracing completely at runtime. This can speed up runtime by reducing the tracing overhead to zero, while the soundness of the graph is still guaranteed by dynamo’s lighter-weight guard system.
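To make the contrast concrete, here is a minimal, purely illustrative Python sketch of the two reuse mechanisms; none of these names come from the actual dynamo or PyTorch/XLA code.

```python
# Illustrative sketch only: contrasts hash-based reuse (trace on every call) with
# guard-based reuse (trace once, check cheap predicates). All names are made up.
compiled_cache = {}

def run_with_hash_cache(trace_fn, compile_fn, inputs):
    # PyTorch/XLA style: pay the tracing cost on every call, then look up the
    # compiled executable by the hash of the traced graph.
    graph = trace_fn(inputs)
    key = hash(graph)
    if key not in compiled_cache:
        compiled_cache[key] = compile_fn(graph)
    return compiled_cache[key](inputs)

def run_with_guards(guards, compiled_graph, recompile_fn, inputs):
    # dynamo style: the graph was traced once at compile time; at runtime we only
    # evaluate cheap guard predicates and re-trace only when a guard fails.
    if all(guard(inputs) for guard in guards):
        return compiled_graph(inputs)
    return recompile_fn(inputs)
```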
Here are the previous dynamo updates:
- Update 1: An Experiment in Dynamic Python Bytecode Transformation
- Update 2: 1.48x Geomean Speedup on TorchBench CPU Inference
- Update 3: GPU Inference Edition
- Update 4: Lazy Tensors & nvFuser Experiments
- Update 5: Improved Capture and Bigger Graphs
- Update 6: Training support with AOTAutograd
- Update 7: Inference with FX2TRT
- Update 8: TorchDynamo passed correctness check on 7k+ github models
- Update 9: Making DDP Work with TorchDynamo
Correctness Verification
People may wonder whether we can still guarantee correctness if we skip PyTorch/XLA tracing completely. To verify correctness, we do multiple runs for both the baseline and the test, each run using different inputs. For inference, we check the numerical correctness of the prediction; for training, we check the numerical correctness of the prediction, the loss, and the gradients. We make sure all these correctness checks pass.
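As a rough illustration, the check has the following shape (a minimal sketch; `baseline_fn`, `test_fn`, and `make_inputs` are placeholders for running the model through plain PyTorch/XLA, running it through the bridge, and generating fresh inputs):

```python
import torch

def check_inference_correctness(baseline_fn, test_fn, make_inputs, num_runs=3):
    # Run both paths several times with fresh inputs and compare predictions.
    # For training we additionally compare the loss and the gradients.
    for _ in range(num_runs):
        inputs = make_inputs()
        expected = baseline_fn(*inputs)
        actual = test_fn(*inputs)
        assert torch.allclose(expected.cpu(), actual.cpu(), rtol=1e-4, atol=1e-4)
```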
Testing Environment
All the tests for GPU are done inside the PyTorch/XLA docker container (Google Cloud console). The relevant specs of the host system:
- CPU: AMD Ryzen Threadripper PRO 3975WX 32-Cores
- GPU: NVIDIA GeForce RTX 3080 with 10GB high-bandwidth memory
For the tests on TPU, we use TPU v4, which is the same hardware generation as the A100 GPU. The host has an AMD EPYC 7B12 CPU with 120 cores.
Inference Perf Result on GPU
We test on a subset of pytorch/benchmark models that don’t trigger PyTorch/XLA fallbacks. A PyTorch/XLA fallback causes a graph break that dynamo is unaware of, which makes it challenging to construct an XLA graph equivalent to the FX graph dynamo provides for optimization. Will Constable brought up the idea of letting the bridge tell dynamo about all the potential XLA fallbacks, so that dynamo can break the FX graph properly and only hand the bridge (potentially smaller) FX graphs that will not trigger PyTorch/XLA fallbacks. Handling PyTorch/XLA fallbacks will be done as a follow-up.
We use PyTorch/XLA as the baseline. Dynamo’s benchmarking script uses PyTorch eager as the baseline by default, but we think it’s better to use the same framework for both the baseline and the test to avoid numerical differences introduced by different frameworks.
To make the test complete, we built two bridges to integrate dynamo with PyTorch/XLA. ‘trace_once’ is the bridge that traces once at compile time and skips tracing at runtime. ‘trace_everytime’ is the trivial bridge that still does a PyTorch/XLA trace on every run. Unless explicitly mentioned otherwise, ‘the bridge’ refers to the ‘trace_once’ bridge in this note.
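For reference, here is roughly how a model is run through either bridge (a usage sketch; the backend names passed to torchdynamo.optimize are illustrative and may differ from the registered ones):

```python
import torch
import torchdynamo
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 128).to(device).eval()
inputs = torch.randn(64, 128, device=device)

# "torchxla_trace_once" / "torchxla_trivial" stand in for the trace_once and
# trace_everytime bridges respectively (names are illustrative).
with torchdynamo.optimize("torchxla_trace_once"):
    out = model(inputs)
```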
Here are takeaways from the testing result:
- The trace_everytime bridge’s perf is neutral for each model. There is not much overhead introduced by dynamo.
- trace_once has a 1.5x geomean speedup. Skipping tracing indeed speeds up inference significantly. For the transformer models (BERT_pytorch, timm_vision_transformer) the speedups are as high as 3.0x and 2.1x respectively, which shows PyTorch/XLA incurs higher tracing overhead for transformer models.
Here is a trace for resnet50 to illustrate the perf win (note that enabling profiling may skew the time measurements; it is only meant to give a high-level understanding of where time goes):
We can see the baseline has significant tracing overhead, while the test (the bridge) gets rid of the tracing overhead with neutral computation time. For alexnet, tracing overhead is much smaller compared to computation, so we don’t see significant savings from removing it.
Inference Perf Result on TPU
On TPU, the trace_once bridge shows an even larger geomean speedup: 1.8x. The reason is that the TPU devices are more powerful than the GPU we use for testing. A more powerful device means faster computation, so avoiding tracing has a bigger relative impact. E.g., for resnet50 the computation time on TPU vs. GPU is 6.7 ms vs. 20.7 ms. Technically, the TPU environment also has more powerful CPUs, so the tracing overhead should be smaller too, but that effect is not as big as the reduction in computation time.
The trace_everytime bridge is also neutral on TPU, as on GPU, which again shows dynamo’s overhead is very low.
Training on GPU
We rely on AOTAutograd to train models with dynamo. AOTAutograd calls the bridge twice: once with the forward graph and once with the backward graph. The optimized forward/backward graphs are wrapped into a torch.autograd.Function.
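A minimal sketch of the mechanism, using functorch’s aot_function (the API available at the time of writing) with a placeholder compiler standing in for the real bridge:

```python
import torch
from functorch.compile import aot_function

def xla_bridge_compiler(fx_graph, example_inputs):
    # The real bridge would trace `fx_graph` with PyTorch/XLA once and return a
    # callable replaying the saved XLA graph; here we just run the FX graph.
    return fx_graph

def fn(x, w):
    return (x @ w).relu().sum()

# AOTAutograd invokes the compiler once for the forward graph and once for the
# backward graph; the result is wrapped in a torch.autograd.Function internally.
compiled_fn = aot_function(fn, fw_compiler=xla_bridge_compiler,
                           bw_compiler=xla_bridge_compiler)

x = torch.randn(4, 4, requires_grad=True)
w = torch.randn(4, 4, requires_grad=True)
compiled_fn(x, w).backward()
```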
Training is much trickier than inference for the integration:
- In the training case, PyTorch/XLA (the baseline) generates a single combined graph for fwd/bwd/optimizer, while the trace_once bridge generates multiple smaller graphs: one for forward, one for backward, and a couple for the optimizer. XLA favors larger graphs since it can do more optimizations on them.
- In the training case, tracing overhead can be overlapped with computation, so tracing overhead is not as big a deal for training as for inference. After all, training cares more about throughput while inference cares more about latency.
- In the training case, people can increase the batch size to ‘mitigate’ the tracing overhead. Increasing the batch size does not change the tracing overhead, but it makes the per-example tracing overhead appear smaller.
Even so, we still want to explore integrating dynamo with PyTorch/XLA for training since
- we can provide consistent UX for inference and for training
- we can still have perf gains for models with high tracing overhead
Here are the results for training:
Some notes about the result:
- We have to reduce the default batch size used by dynamo for the following models to avoid OOM in the testing environment:
  - resnet50: 32 → 16
  - mobilenet_v2: 96 → 16
  - vgg16: 64 → 8
  - BERT_pytorch: 16 → 2
- There is a device synchronization after each training iteration.
Some takeaways from the results:
- Overall, trace_once is slightly better than neutral (1.05x speedup) and trace_everytime is slower than the baseline (0.96x). Note that our baseline is PyTorch/XLA rather than eager, which is already a very strong baseline.
- For models incurring high tracing overhead, we still see 1.7x (BERT_pytorch) and 1.4x (timm_vision_transformer) speedups.
Dive into the Perf Numbers for resnet50 on GPU
We dive into the perf numbers for the resnet50 model. We think the perf is mainly determined by the interaction of three factors: reduce_tracing_overhead, penalty_for_graph_breaks and penalty_from_aot_autograd.
The bridge incurs penalty_for_graph_breaks since we generate multiple graphs for training. A typical training loop has the following steps: zero_grad → fwd → bwd → optimizer. In our bridge, each of these steps results in an XLA graph. To understand the impact of graph breaks, we manually inject a graph break between the fwd and bwd graphs in the baseline. The XLA ExecuteTime metric (which represents the execution time on the TPU/GPU device) increases by 7.8%, from 41.2 ms to 44.4 ms. We think the graph break loses some fusion opportunities, which causes the perf loss. Comparing the XLA ExecuteTime between the baseline (a single graph) and the test, there is an 11.9% increase (41.2 ms vs. 46.1 ms).
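For the graph break experiment, a sketch of what the injection looks like in the baseline training loop (assuming the standard torch_xla APIs; ExecuteTime comes from torch_xla’s metrics report):

```python
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

def train_step(model, optimizer, loss_fn, data, target):
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    xm.mark_step()   # injected graph break: cuts the XLA graph between fwd and bwd
    loss.backward()
    optimizer.step()
    xm.mark_step()   # the usual per-iteration step marker

# After a few iterations, look for the ExecuteTime metric:
# print(met.metrics_report())
```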
AOTAutograd also causes some perf loss, since we see about a 4% slowdown in the trace_everytime bridge as well; that loss is unlikely to come from dynamo, since trace_everytime is neutral for inference. One cause of perf loss in AOTAutograd is that we need to apply updates to mutated inputs (pytorch/aot_autograd.py at master · pytorch/pytorch · GitHub). This costs about 2 ms for resnet50.
Whether or not we see a perf gain depends on whether: reduce_tracing_overhead > penalty_for_graph_breaks + penalty_from_aot_autograd + other_minor_factors_we_have_not_considered
Here is one idea to win back the lost perf: rely on dynamo to capture the guarded graph for the forward pass, then use PyTorch/XLA to construct the whole forward/backward/optimizer training graph, and have the bridge reuse that whole training graph directly. This also avoids the overhead introduced by the graph breaks between forward, backward and the optimizer. But it may be hard to apply this improvement when there is a graph break inside the forward pass. We’ll explore this as a follow-up.
Deep Dive into Some of the Tech Details
Match FX Graph Input/Output with XLA Graph Input/Output
Here is how the bridge works in general. Dynamo passes an FX graph (wrapped in an fx.GraphModule) to the bridge. The bridge needs to return an optimized_mod which does the same thing as the FX graph but (hopefully) faster. The inputs/outputs of optimized_mod match the input/placeholder nodes and output nodes of the FX graph.
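In other words, the bridge is a standard dynamo backend: it receives an fx.GraphModule plus example inputs and must return a callable with matching inputs/outputs. A minimal sketch of that contract (the body is a placeholder, not the real bridge):

```python
import torch.fx as fx

def xla_backend(gm: fx.GraphModule, example_inputs):
    # The real bridge traces `gm` with PyTorch/XLA here (once, at compile time)
    # and saves the resulting XLA graph. This placeholder just runs `gm` as-is.
    def optimized_mod(*args):
        return gm(*args)
    return optimized_mod
```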
Inside optimized_mod we need to call the saved XLA graph, but the inputs/outputs of the XLA graph can be very different from the FX graph’s inputs/outputs.
The inputs to an XLA graph are all the DeviceData nodes collected in trace order. A DeviceData node is a leaf node in the PyTorch/XLA IR that represents data rather than computation. Some of these DeviceData nodes represent FX graph inputs, but some just represent a const tensor/scalar implicitly defined by a certain op, e.g. the alpha parameter of the add op (pytorch/native_functions.yaml at master · pytorch/pytorch · GitHub). The FX graph inputs may also appear in a different order among the XLA graph inputs because the latter are ordered by tracing. We build a GraphInputMatcher class to construct the XLA graph inputs from the FX graph inputs and the const tensors/scalars collected at tracing time.
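A conceptual sketch of what GraphInputMatcher does (not the actual implementation; names and structures are illustrative):

```python
class GraphInputMatcher:
    """Rebuild the XLA graph input list, in trace order, from the FX graph inputs."""

    def __init__(self, trace_order_ids, id_to_fx_position, id_to_traced_constant):
        self.trace_order_ids = trace_order_ids              # DeviceData node ids in trace order
        self.id_to_fx_position = id_to_fx_position          # node id -> index into FX graph inputs
        self.id_to_traced_constant = id_to_traced_constant  # node id -> constant captured at trace time

    def __call__(self, fx_inputs):
        return [
            fx_inputs[self.id_to_fx_position[node_id]]
            if node_id in self.id_to_fx_position
            else self.id_to_traced_constant[node_id]
            for node_id in self.trace_order_ids
        ]
```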
XLA may also return a different list of outputs than the FX graph. E.g., XLA does not return an output if it’s also an input, unless the tensor is updated in place, and XLA deduplicates identical outputs. All these cases do happen for FX graphs generated by AOTAutograd. Check verify the number of outputs of xla graph by shunting314 · Pull Request #89536 · pytorch/pytorch · GitHub for details.
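Conceptually, restoring the FX output structure from the (possibly shorter) XLA output list looks like this (an illustrative sketch; the index maps would be recorded once when the XLA graph is saved):

```python
def restore_fx_outputs(xla_outputs, fx_to_xla_index, passthrough_fx_input_index, fx_inputs):
    # fx_to_xla_index[i] is the position of FX output i in the XLA output list
    # (duplicated FX outputs map to the same position), or None if XLA omitted
    # the output because it is also a non-mutated input.
    fx_outputs = []
    for i, xla_idx in enumerate(fx_to_xla_index):
        if xla_idx is None:
            # Hand the corresponding input tensor straight through.
            fx_outputs.append(fx_inputs[passthrough_fx_input_index[i]])
        else:
            fx_outputs.append(xla_outputs[xla_idx])
    return fx_outputs
```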
Random Seed
The random seed is also represented as a DeviceData node. But it’s special: at runtime we need to pass the current seed value maintained by XLA into the XLA graph to guarantee the same randomness between the test and the baseline. Improper handling of the random seed causes ops like dropout to produce different results for the test and the baseline.
Tracing on XLA
In an earlier version of the implementation, we let dynamo trace the model on CPU. Before and after running each optimized_mod, the bridge had to move inputs/outputs between CPU and the XLA device, which caused a performance loss. The problem is even more severe with AOTAutograd, since the fwd/bwd graphs it generates tend to have a lot of tensors as inputs and outputs.
Later on we changed to letting dynamo trace the model on XLA directly. This improved the inference geomean speedup on GPU from 1.4x to 1.5x. Also, the UX is more similar to how people usually enable PyTorch/XLA.
Follow Ups
- Solve PyTorch/XLA fallback problem in the bridge for dynamo and PyTorch/XLA integration: [Feature Request][XLA] Support fallback for the dynamo-xla bridge · Issue #1823 · pytorch/torchdynamo · GitHub
- Explore ways to reduce the perf loss caused by graph breaks and AOTAutograd in training
- Do more tests
- with larger batch size or less frequent device sync for training
- on more models