TorchDynamo Update 10: Integrating with PyTorch/XLA for Inference and Training

with Will Constable, Jason Ansel
with Jack Cao from Google PyTorch/XLA team

TLDR: We’ve built a prototype bridge to integrate dynamo with PyTorch/XLA. We benchmarked the bridge on a subset of 10 pytorch/benchmark models. For inference, we verified numerical correctness and achieved a 1.5x geomean speedup on GPU and a 1.8x geomean speedup on TPU compared to the PyTorch/XLA baseline. For training, we verified numerical correctness, measured the overall performance on GPU (slightly better than neutral), and analyzed the causes of the slowdown on some models.

There are two major motivations for integrating dynamo with PyTorch/XLA. First, it enables users to run models on TPU through dynamo; second, we can leverage dynamo’s guard system to skip PyTorch/XLA’s tracing overhead.

Dynamo and PyTorch/XLA use different mechanisms to guarantee sound graph capture (see The nuances of PyTorch Graph Capture for a comprehensive comparison of graph capture techniques). Dynamo uses its guard system to record the conditions that must hold for a captured graph to be reused, and checking these guards is very cheap. PyTorch/XLA, on the other hand, traces the graph on every run and maintains a hash of it; a previously captured and optimized graph is reused if the hash hits the cache. The tracing overhead varies by model, but it’s usually quite high.
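To make the contrast concrete, here is a minimal pure-Python toy, assuming a "graph" is simply the op list a model emits for a given input shape. TraceAndHashCache, GuardedCache, trace and model are hypothetical names for illustration only, not dynamo or PyTorch/XLA APIs. The hash-based cache pays for tracing on every call, while the guard-based cache only evaluates a cheap predicate once a graph exists.

```python
import hashlib


def trace(fn, shape):
    """Hypothetical stand-in for lazy tracing: record the ops fn emits."""
    return tuple(fn(shape))


class TraceAndHashCache:
    """Toy PyTorch/XLA-style cache: trace every call, reuse on a hash hit."""

    def __init__(self, fn):
        self.fn = fn
        self.compiled = {}
        self.traces = 0

    def __call__(self, shape):
        ops = trace(self.fn, shape)  # tracing cost is paid on every call
        self.traces += 1
        key = hashlib.sha256(repr(ops).encode()).hexdigest()
        if key not in self.compiled:
            self.compiled[key] = ops  # "compile" once per unique graph
        return self.compiled[key]


class GuardedCache:
    """Toy dynamo-style cache: trace once, then only evaluate guards."""

    def __init__(self, fn):
        self.fn = fn
        self.entries = []  # (guard predicate, compiled graph) pairs
        self.traces = 0

    def __call__(self, shape):
        for guard, graph in self.entries:
            if guard(shape):  # cheap check, no tracing
                return graph
        graph = trace(self.fn, shape)
        self.traces += 1
        self.entries.append((lambda s, expected=shape: s == expected, graph))
        return graph


def model(shape):
    # Toy model whose op sequence depends only on the input shape.
    return ["matmul", "relu", f"reshape{shape}"]


hashed, guarded = TraceAndHashCache(model), GuardedCache(model)
for _ in range(100):
    hashed((8, 16))
    guarded((8, 16))
print(hashed.traces, guarded.traces)  # prints: 100 1
```

A new input shape fails the only guard, so GuardedCache traces again and stores a second entry; that is the point where dynamo would recompile.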

Naively integrating dynamo and PyTorch/XLA would leave both dynamo’s guard system and PyTorch/XLA’s tracing system guaranteeing sound graph capture, which is redundant. Instead, we want PyTorch/XLA to trace the graph once at compile time and skip tracing completely at runtime. This can potentially speed up execution by reducing the runtime tracing overhead to zero, while the soundness of the graph is guaranteed by dynamo’s lighter-weight guard system.


Correctness Verification

People may wonder whether we can still guarantee correctness when skipping PyTorch/XLA tracing completely. To verify correctness, we execute multiple runs of both the baseline and the test, with different inputs for each run. For inference, we check the numerical correctness of the predictions; for training, we check the numerical correctness of the predictions, loss, gradients, etc. We make sure all of these correctness checks pass.
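The check can be sketched as follows, assuming each run returns a dict of named outputs (prediction, loss, gradients) as flat lists of floats. verify, allclose and the toy baseline/test functions are hypothetical; allclose uses the same |a - b| <= atol + rtol * |b| criterion and default tolerances as torch.allclose. Using different inputs on each run matters: a bridge that wrongly replays a stale result can pass a single-run check but fails this one.

```python
import random


def allclose(a, b, rtol=1e-5, atol=1e-8):
    """Element-wise |a - b| <= atol + rtol * |b| (torch.allclose's criterion)."""
    return len(a) == len(b) and all(
        abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b)
    )


def verify(baseline, test, make_inputs, num_runs=3, seed=0):
    """Compare every named output of test against baseline over several runs."""
    rng = random.Random(seed)
    for run in range(num_runs):
        inputs = make_inputs(rng)  # fresh inputs for every run
        expected, actual = baseline(inputs), test(inputs)
        for name, values in expected.items():
            if not allclose(actual[name], values):
                raise AssertionError(f"run {run}: mismatch in {name!r}")
    return True


def baseline(xs):
    pred = [2.0 * x for x in xs]
    return {"prediction": pred, "loss": [sum(p * p for p in pred)]}


def good_test(xs):
    return baseline(xs)


_replayed = {}


def stale_test(xs):
    # Buggy "bridge" that replays the first result it ever produced.
    return _replayed.setdefault("out", baseline(xs))


def make_inputs(rng):
    return [rng.uniform(-1.0, 1.0) for _ in range(4)]
```

Here verify(baseline, good_test, make_inputs) returns True, while the stale bridge passes run 0 and is caught on run 1.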

Testing Environment

All GPU tests are done inside the PyTorch/XLA docker container (Google Cloud console). The relevant specs of the host system are:

  • CPU: AMD Ryzen Threadripper PRO 3975WX 32-Cores
  • GPU: NVIDIA GeForce RTX 3080 with 10 GB of memory

For the TPU tests, we use TPU v4, which is the same generation as the NVIDIA A100 GPU. The host has an AMD EPYC 7B12 CPU with 120 cores.

Inference Perf Result on GPU

We test on a subset of pytorch/benchmark models that don’t trigger PyTorch/XLA fallbacks. A PyTorch/XLA fallback causes a graph break that dynamo is unaware of, which makes it challenging to construct an XLA graph equivalent to the FX graph dynamo provides for optimization. Will Constable proposed letting the bridge tell dynamo about all potential XLA fallbacks, so that dynamo breaks the FX graph properly and only hands the bridge (potentially smaller) FX graphs that will not trigger PyTorch/XLA fallbacks. Handling PyTorch/XLA fallbacks will be done as a follow-up.
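The proposal amounts to a partitioning pass. This is a hypothetical illustration: it treats the FX graph as a linear op sequence (a real FX graph is a DAG) and assumes the bridge can report a set of ops that would fall back; partition_for_xla and the op names are made up for the example.

```python
def partition_for_xla(ops, fallback_ops):
    """Split a linear op sequence into XLA segments and eager fallback ops.

    Each returned segment is (kind, ops): "xla" for a run of ops the bridge
    can compile, or "eager" for a single op that would fall back.
    """
    segments, current = [], []
    for op in ops:
        if op in fallback_ops:
            if current:  # close the XLA segment before the fallback op
                segments.append(("xla", current))
                current = []
            segments.append(("eager", [op]))
        else:
            current.append(op)
    if current:
        segments.append(("xla", current))
    return segments


print(partition_for_xla(["conv", "relu", "nms", "add", "softmax"], {"nms"}))
```

Each "xla" segment would then be a smaller FX graph that the bridge can trace once without hitting a fallback.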

We use PyTorch/XLA as the baseline, whereas dynamo’s benchmarking script uses PyTorch eager as the baseline by default. We think it’s more convenient to use the same framework for baseline and test, to avoid numerical differences introduced by different frameworks.

For completeness, we built two bridges to integrate dynamo with PyTorch/XLA. ‘trace_once’ is the bridge that traces once at compile time and skips tracing at runtime. ‘trace_everytime’ is the trivial bridge that still runs PyTorch/XLA tracing on every run. Unless stated otherwise, “the bridge” in this note refers to the ‘trace_once’ bridge.
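The two bridges differ essentially in where tracing happens. Below is a minimal sketch, assuming a backend receives a graph callable plus example inputs and returns optimized_mod, mirroring the shape of a dynamo backend; ToyTracer and trace_and_compile are hypothetical stand-ins for PyTorch/XLA tracing and compilation, not real APIs.

```python
class ToyTracer:
    """Hypothetical stand-in for PyTorch/XLA lazy tracing plus compilation."""

    def __init__(self):
        self.trace_calls = 0
        self.cache = {}

    def trace_and_compile(self, graph_fn, args):
        self.trace_calls += 1
        key = (graph_fn.__name__, len(args))  # toy substitute for the graph hash
        if key not in self.cache:
            self.cache[key] = graph_fn  # toy "compiled executable"
        return self.cache[key]


def trace_everytime(graph_fn, example_inputs, tracer):
    def optimized_mod(*args):
        executable = tracer.trace_and_compile(graph_fn, args)  # traced per call
        return executable(*args)
    return optimized_mod


def trace_once(graph_fn, example_inputs, tracer):
    # Trace at compile time only; dynamo's guards decide when this is reused.
    executable = tracer.trace_and_compile(graph_fn, example_inputs)

    def optimized_mod(*args):
        return executable(*args)
    return optimized_mod


def fx_graph(x, y):
    # Stand-in for the FX graph dynamo hands to the bridge.
    return x * y + 1


every_tracer, once_tracer = ToyTracer(), ToyTracer()
every = trace_everytime(fx_graph, (1, 2), every_tracer)
once = trace_once(fx_graph, (1, 2), once_tracer)
results = [(every(i, 3), once(i, 3)) for i in range(10)]
print(every_tracer.trace_calls, once_tracer.trace_calls)
```

Both bridges return the same values; only the number of trace calls differs, which is exactly the overhead trace_once removes.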

Here are takeaways from the testing result:

  • The trace_everytime bridge’s perf is neutral for each model, so dynamo introduces little overhead.
  • trace_once has a 1.5x geomean speedup: skipping tracing indeed speeds up inference significantly. For transformer models (BERT_pytorch, timm_vision_transformer) the speedup reaches 3.0x and 2.1x respectively, which shows PyTorch/XLA incurs higher tracing overhead for transformer models.
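For reference, a geomean speedup is the geometric mean of per-model speedups. The sketch below assumes each speedup is computed as baseline time divided by test time; the timings in the example are hypothetical. Unlike the arithmetic mean, the geometric mean treats a 2x win and a 2x loss as cancelling out.

```python
import math


def speedup(baseline_ms, test_ms):
    """Per-model speedup: > 1.0 means the test is faster than the baseline."""
    return baseline_ms / test_ms


def geomean(values):
    """Geometric mean, computed in log space for numerical stability."""
    if not values:
        raise ValueError("geomean of an empty list is undefined")
    return math.exp(sum(math.log(v) for v in values) / len(values))


# Hypothetical per-model timings (ms): one 2x win and one 2x loss.
speedups = [speedup(20.0, 10.0), speedup(10.0, 20.0)]
print(geomean(speedups))  # the arithmetic mean would be 1.25
```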

Here is a trace for resnet50 that illustrates the perf win (note that enabling profiling may make time measurements inaccurate; use it only for a high-level understanding of the time cost):

We can see that the baseline has significant tracing overhead, while the test (the bridge) gets rid of the tracing overhead with neutral computation time. For alexnet, the tracing overhead is much smaller than the computation time, so we don’t see significant savings from removing it.

Inference Perf Result on TPU

On TPU, the trace_once bridge has an even larger geomean speedup: 1.8x. The reason is that the TPU devices are more powerful than the GPU we use for testing. With a more powerful device, computation is faster, so avoiding tracing has a bigger relative impact. E.g., for resnet50 the computation time on TPU and GPU is 6.7 ms vs. 20.7 ms respectively. Technically, the TPU environment also has more powerful CPUs, so its tracing overhead should also be smaller, but that effect is not as big as the reduction in computation time.
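The argument can be checked with simple arithmetic. If inference latency is roughly tracing time T plus compute time C (no overlap), removing tracing yields a speedup of (T + C) / C, which grows as C shrinks. The compute times below are the resnet50 numbers quoted above; the tracing times are hypothetical, with a smaller value for the TPU host to model its faster CPU.

```python
def trace_removal_speedup(trace_ms, compute_ms):
    """Latency speedup from skipping tracing, assuming no trace/compute overlap."""
    return (trace_ms + compute_ms) / compute_ms


# resnet50 inference compute times quoted above (ms).
TPU_COMPUTE_MS = 6.7
GPU_COMPUTE_MS = 20.7
# HYPOTHETICAL tracing costs; the TPU host's faster CPU is modeled as a smaller value.
GPU_HOST_TRACE_MS = 10.0
TPU_HOST_TRACE_MS = 8.0

tpu = trace_removal_speedup(TPU_HOST_TRACE_MS, TPU_COMPUTE_MS)
gpu = trace_removal_speedup(GPU_HOST_TRACE_MS, GPU_COMPUTE_MS)
print(f"TPU: {tpu:.2f}x, GPU: {gpu:.2f}x")
```

Even with less tracing overhead on the TPU host, the much shorter compute time makes the relative gain larger there.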

The trace_everytime bridge is also neutral, as on GPU, which shows that dynamo’s overhead is very low.

Training on GPU

We rely on AOTAutograd to train models in dynamo. AOTAutograd will call the bridge twice: once with the forward graph and once with the backward graph. The optimized forward/backward graph will be wrapped into a torch.autograd.Function.

Training is much trickier than inference for the integration:

  1. In the training case, PyTorch/XLA (baseline) generates a single combined graph for fwd/bwd/optimizer, while the trace_once bridge generates multiple smaller graphs: one for forward, one for backward and a couple for the optimizer. XLA favors larger graphs because they expose more optimization opportunities.
  2. In the training case, tracing overhead can be overlapped with computation, so it is not as big a deal as it is for inference. After all, training cares more about throughput while inference cares more about latency.
  3. In the training case, people can increase the batch size to ‘mitigate’ the tracing overhead. Increasing the batch size does not change the tracing overhead, so the tracing overhead ‘per example’ appears to shrink.
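Point 3 can be made concrete with a toy cost model, assuming the per-step tracing cost is fixed and compute time grows linearly with batch size (real devices are rarely that linear). tracing_share, per_example_trace_ms and the numbers in the example are hypothetical.

```python
def tracing_share(trace_ms, compute_ms_per_example, batch_size):
    """Fraction of a training step spent tracing under a linear compute model."""
    step_ms = trace_ms + compute_ms_per_example * batch_size
    return trace_ms / step_ms


def per_example_trace_ms(trace_ms, batch_size):
    """Fixed tracing cost amortized over the examples in one batch."""
    return trace_ms / batch_size


# Hypothetical: 10 ms of tracing per step, 1 ms of compute per example.
for batch in (8, 32, 128):
    print(batch, round(tracing_share(10.0, 1.0, batch), 3), per_example_trace_ms(10.0, batch))
```

The absolute tracing cost per step never changes; only its share of the step shrinks, which is why larger batches hide rather than remove it.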

Even so, we still want to explore integrating dynamo with PyTorch/XLA for training since

  1. we can provide consistent UX for inference and for training
  2. we can still have perf gains for models with high tracing overhead

Here are the results for training:

Some notes about the result:

  1. We had to reduce the default batch sizes used by dynamo for the following models to avoid OOM in the testing environment:

    • resnet50: 32 → 16
    • mobilenet_v2: 96 → 16
    • vgg16: 64 → 8
    • BERT_pytorch: 16 → 2
  2. There is a device synchronization after each training iteration

Some takeaways from the result

  • Overall, trace_once is slightly better than neutral (1.05x speedup) and trace_everytime is slower than the baseline (0.96x speedup). Note that our baseline is PyTorch/XLA rather than eager, which is already a very strong baseline.
  • For models incurring high tracing overhead, we still see 1.7x (BERT_pytorch) and 1.4x (timm_vision_transformer) speedups

Dive into the Perf Numbers for resnet50 on GPU

We dive into the perf numbers for the resnet50 model. We think they are mainly determined by the interaction of 3 factors: reduce_tracing_overhead, penalty_for_graph_breaks and penalty_from_aot_autograd.

The bridge incurs penalty_for_graph_breaks since we generate multiple graphs for training. A typical training loop has the following steps: zero_grad → fwd → bwd → optimizer. In our bridge, each of these steps results in an XLA graph. To understand the impact of graph breaks, we manually inject a graph break between the fwd and bwd graphs in the baseline. The XLA ExecuteTime metric (the execution time on the TPU/GPU device) increases by 7.8%, from 41.2 ms to 44.4 ms. We think the graph break loses some fusion opportunities, which causes the perf loss. Comparing the XLA ExecuteTime between the baseline (a single graph) and the test, there is an 11.9% increase (41.2 ms vs. 46.1 ms).
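As a quick sanity check, the percentages quoted above follow directly from the reported ExecuteTime numbers; pct_increase is just a helper name for this arithmetic.

```python
def pct_increase(before_ms, after_ms):
    """Relative increase of after_ms over before_ms, in percent."""
    return (after_ms - before_ms) / before_ms * 100.0


# XLA ExecuteTime for resnet50 training on GPU, as reported above (ms).
single_graph = 41.2  # baseline: one combined fwd/bwd/optimizer graph
manual_break = 44.4  # baseline with a graph break injected between fwd and bwd
bridge = 46.1        # trace_once bridge (one graph per training-loop step)

print(f"manual fwd/bwd break: +{pct_increase(single_graph, manual_break):.1f}%")
print(f"trace_once bridge:    +{pct_increase(single_graph, bridge):.1f}%")
```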

AOTAutograd also causes some perf loss: we see about a 4% slowdown for the trace_everytime bridge as well, and that loss is less likely to come from dynamo, since trace_everytime is neutral for inference. One cause of perf loss in AOTAutograd is that we need to apply updates to mutated inputs (pytorch/ at master · pytorch/pytorch · GitHub). This costs about 2 ms for resnet50.

Whether or not we see a perf gain depends on whether: reduce_tracing_overhead > penalty_for_graph_breaks + penalty_from_aot_autograd + other_minor_factors_we_have_not_considered
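The inequality can be written as a tiny cost model. StepCostModel and all numbers below are hypothetical: the 2 ms AOTAutograd penalty echoes the input-mutation cost quoted for resnet50 and the 4.9 ms graph-break penalty echoes the ExecuteTime gap above, but treating those measurements as the full penalties is an assumption, and the tracing savings are made up.

```python
from dataclasses import dataclass


@dataclass
class StepCostModel:
    """Per-step costs in ms; every number used below is illustrative only."""

    reduce_tracing_overhead: float
    penalty_for_graph_breaks: float
    penalty_from_aot_autograd: float
    other_minor_factors_we_have_not_considered: float = 0.0

    def net_gain_ms(self):
        return self.reduce_tracing_overhead - (
            self.penalty_for_graph_breaks
            + self.penalty_from_aot_autograd
            + self.other_minor_factors_we_have_not_considered
        )

    def is_faster(self):
        return self.net_gain_ms() > 0


# Same penalties, two hypothetical amounts of tracing overhead removed.
low = StepCostModel(5.0, 4.9, 2.0)
high = StepCostModel(10.0, 4.9, 2.0)
print(low.is_faster(), high.is_faster())
```

Models with high tracing overhead (such as the transformer models above) sit on the "high" side of this balance, which matches where training still shows speedups.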

Here is one idea to recover the lost perf. We rely on dynamo to capture the guarded graph for the forward pass, then use PyTorch/XLA to construct the whole forward/backward/optimizer training graph, and in the bridge we reuse that whole training graph directly. This proposal also avoids the overhead introduced by graph breaks between forward, backward and optimizer. However, it may be hard to apply when there is a graph break in the forward pass. We’ll explore this as a follow-up.

Deep Dive into Some of the Tech Details

Match FX Graph Input/Output with XLA Graph Input/Output

Here is how the bridge works in general. Dynamo passes an FX graph (wrapped in an fx.GraphModule) to the bridge. The bridge needs to return an optimized_mod that does the same thing as the FX graph but (hopefully) faster. The inputs/outputs of optimized_mod match the placeholder (input) nodes and output nodes of the FX graph.

In optimized_mod we need to call the saved XLA graph, but the inputs/outputs of the XLA graph can be very different from the FX graph’s inputs/outputs.

The inputs to an XLA graph are all the DeviceData nodes, collected in trace order. A DeviceData node is a leaf node in the PyTorch/XLA IR that represents data rather than computation. Some of these DeviceData nodes represent FX graph inputs, but some may just represent constant tensors/scalars implicitly defined by certain ops, e.g. the alpha parameter of the add op (pytorch/native_functions.yaml at master · pytorch/pytorch · GitHub). The FX inputs may also appear in a different order among the XLA graph inputs, because the latter follow the tracing order. We built a GraphInputMatcher class to construct the XLA graph inputs from the FX graph inputs and the constant tensors/scalars collected at tracing time.
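The idea behind the matcher can be sketched in a few lines. This is a hypothetical re-implementation, not the actual GraphInputMatcher code: it assumes FX inputs can be recognized among the traced DeviceData by object identity, and uses plain lists as toy "tensors" so that each has a distinct identity.

```python
class GraphInputMatcher:
    """Hypothetical sketch: build XLA graph args from FX args plus constants.

    At trace time we record, for each XLA graph input (a DeviceData node, in
    trace order), either the FX placeholder it came from or the constant
    value that was materialized (e.g. the alpha of add).
    """

    def __init__(self, fx_example_inputs, traced_device_data):
        fx_position = {id(t): i for i, t in enumerate(fx_example_inputs)}
        self.plan = []
        for data in traced_device_data:
            if id(data) in fx_position:
                self.plan.append(("fx", fx_position[id(data)]))
            else:
                self.plan.append(("const", data))  # captured once at trace time

    def __call__(self, fx_args):
        """Reorder runtime FX args into XLA input order, filling in constants."""
        return [fx_args[v] if kind == "fx" else v for kind, v in self.plan]


# Toy "tensors" are lists so that each one has a distinct identity.
a, b = [1.0], [2.0]
alpha = 1.0
matcher = GraphInputMatcher([a, b], [b, alpha, a])  # trace order: b, alpha, a
new_a, new_b = [3.0], [4.0]
print(matcher([new_a, new_b]))
```

The plan is computed once at compile time, so at runtime the bridge only does a cheap reordering instead of re-tracing to discover the inputs.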

XLA may also return a different list of outputs than the FX graph. E.g., XLA does not return an output that is also an input unless the tensor is updated in place, and XLA deduplicates identical outputs. All of these cases do happen for FX graphs generated by AOTAutograd. See verify the number of outputs of xla graph by shunting314 · Pull Request #89536 · pytorch/pytorch · GitHub for details.
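The output side can be handled with a plan built at trace time. This is a hypothetical sketch (build_output_plan and assemble_outputs are illustrative names, and identity-based matching is an assumption) of the two behaviors described above: un-mutated inputs are forwarded from the FX arguments, and duplicated outputs map to a single XLA output slot. The slot count also gives the expected number of XLA outputs, which is the kind of check the PR above adds.

```python
def build_output_plan(fx_inputs, fx_outputs, mutated_inputs=()):
    """Hypothetical: map each FX output to where its runtime value comes from."""
    input_pos = {id(t): i for i, t in enumerate(fx_inputs)}
    mutated = {id(t) for t in mutated_inputs}
    xla_slot = {}  # id(output) -> index in XLA's deduplicated output list
    plan = []
    for out in fx_outputs:
        key = id(out)
        if key in input_pos and key not in mutated:
            plan.append(("input", input_pos[key]))  # forward the FX arg as-is
        else:
            xla_slot.setdefault(key, len(xla_slot))
            plan.append(("xla", xla_slot[key]))
    return plan, len(xla_slot)  # len(xla_slot): expected number of XLA outputs


def assemble_outputs(plan, fx_args, xla_outputs):
    """Rebuild the FX-visible output list from the XLA results."""
    return [fx_args[i] if src == "input" else xla_outputs[i] for src, i in plan]


x, y, z = [1.0], [2.0], [3.0]  # y is updated in place, z is computed
plan, num_xla_outputs = build_output_plan([x, y], [z, x, z, y], mutated_inputs=[y])
print(plan, num_xla_outputs)
```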

Random Seed

The random seed is also represented as a DeviceData node, but it’s special: at runtime we need to pass the proper seed value maintained by XLA into the XLA graph to guarantee the same randomness between test and baseline. Improper handling of the random seed causes ops like Dropout to produce different results for test and baseline.
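A toy model shows why the seed must be fetched at runtime rather than baked in at trace time. Everything here is hypothetical: ToyRngState stands in for the seed the XLA runtime maintains (its LCG update is not PyTorch/XLA’s real rule), and dropout_graph stands in for a compiled graph whose seed is an explicit DeviceData input.

```python
import random


class ToyRngState:
    """Hypothetical stand-in for the seed the XLA runtime maintains."""

    def __init__(self, seed):
        self.seed = seed

    def next_seed(self):
        current = self.seed
        # Toy LCG update; the real seed-advancing rule is not modeled here.
        self.seed = (self.seed * 1103515245 + 12345) % (2 ** 31)
        return current


def dropout_graph(xs, seed, p=0.5):
    """Compiled-graph stand-in whose random seed is an explicit input."""
    rng = random.Random(seed)
    return [0.0 if rng.random() < p else x / (1 - p) for x in xs]


def run_baseline(xs, steps, seed):
    state = ToyRngState(seed)  # trace-every-time: fresh seed each step
    return [dropout_graph(xs, state.next_seed()) for _ in range(steps)]


def run_bridge(xs, steps, seed, refresh_seed):
    state = ToyRngState(seed)
    traced_seed = state.next_seed()  # seed value seen while tracing
    outputs = [dropout_graph(xs, traced_seed)]
    for _ in range(steps - 1):
        # Correct bridge: fetch the current seed; buggy bridge: reuse the traced one.
        step_seed = state.next_seed() if refresh_seed else traced_seed
        outputs.append(dropout_graph(xs, step_seed))
    return outputs


xs = [1.0] * 32
print(run_bridge(xs, 3, 7, refresh_seed=True) == run_baseline(xs, 3, 7))
print(run_bridge(xs, 3, 7, refresh_seed=False) == run_baseline(xs, 3, 7))
```

The buggy variant matches the baseline on the first step only, then keeps reusing the same dropout mask.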

Tracing on XLA

In an earlier version of the implementation, we let dynamo trace the model on CPU. Before and after running each optimized_mod, the bridge needed to move inputs/outputs between CPU and XLA, which caused performance loss. This was an even more severe problem with AOTAutograd, since the fwd/bwd graphs generated by AOTAutograd used to have a lot of tensors as inputs and outputs.

Later, we switched to letting dynamo trace the model on XLA directly. This improves the geomean inference speedup on GPU from 1.4x to 1.5x. The UX is also closer to how people usually enable PyTorch/XLA.

Follow Ups


“A PyTorch/XLA fallback will cause graph break that dynamo is unaware of”

Is it a correctness issue or a performance issue? If it is a performance issue, does that mean falling back to aten from Dynamo is more efficient than falling back from PTXLA?

More details are in [Feature Request][XLA] Support fallback for the dynamo-xla bridge · Issue #1823 · pytorch/torchdynamo · GitHub. The implementation assumes that xla returns a single hash for a given fx graph, which does not hold when there is a fallback. Sherlock seems to have something that might let xla tell upstream about its fallback conditions so that dynamo can break the graph for xla, but we need to look into that a bit more.


Improper handling of torchxla fallbacks would cause both correctness and perf issues:

  • for correctness, a torchxla fallback may cause the graph we saved for future execution to not represent the behavior of the original graph we compiled/optimized
  • for performance, we do see that extra graph breaks slow things down. However, for torchxla fallbacks it’s not that big an issue, since the baseline is also slowed down by the fallback.

As Jack also mentioned, we have some ideas and can follow up on this.


If it’s just to implement dynamo’s xla backend, I think reimplementing the fx-to-xla lowering logic using the xla client in jaxlib is a much cleaner way with more predictable performance, because there are many similarities between jax and torch[dynamo].

If the goal is to improve the current support for graph fragmentation in LTC + xla, I think LTC needs to be re-engineered. Its earlier compatibility with TPU node mode now results in high overhead for data representation and data exchange between ltc and eager mode. I don’t think this is an inherent problem with lazy execution itself, just a result of the engineering implementation.

I think by TPU node mode you mean the XRT runtime and its client/gRPC-server nature? We are working on deprecating the XRT runtime and replacing it with PJRT. Do you mind sharing what other overhead you have in mind? When dynamo works, the lazy bridge code should only be triggered the first time, and the rest of the execution should directly call the runtime to execute the graph. Lazy code still handles converting the DeviceData back to the at::Tensor representation, but I don’t think that is too costly.

We chose to build our first prototype on lazy because it is the most convenient way. Features like DS, SPMD and FSDP already need to be implemented for lazy, and the engineering cost of reimplementing everything with dynamo is too high. There is a possibility for xla to provide a native dynamo implementation without the lazy bridge, but that’s more long term. We want to make sure the proof-of-concept bridge works well enough first.


What’s the impact from a UX point of view?
Can users turn on XLA with the one-liner torch.compile(backend="xla")?


That’s the goal, and for inference I think we are already pretty close. You can check our simple test at xla/ at master · pytorch/xla · GitHub, where our backend is called torchxla_trace_once. Training is more complicated since we need to worry about the distributed story etc. Eventually we hope users can move their models to xla with a simple string swap of the backend.
