PyTorch/XLA 2.2 Release Dev Update

Hey, I am here with another dev update, this time for the PyTorch/XLA 2.2 release. It is relatively short compared to the 2.1 update because it only covers three months of development, and those months overlapped with a number of year-end vacations. My 2.1 update shared a fairly detailed roadmap for many features. We still plan to work on most of those, but I don’t want to repeat them here. Instead, I will focus on the features that I think will interest the community the most.

Again, this dev update overlaps with our release notes. The goal here is to explain why we made some of our technical decisions and to share our plan for the next few months.


SPMD continues to be the biggest area of development in 2.2, since we were wrapping up a number of features for our big LLM push last year.

DTensor Integration

We have successfully integrated with PyTorch’s DTensor API, and we plan to shift most of our use cases to DTensor. Users only need to specify "xla" as the device type when creating the mesh. A sample code snippet looks something like this:

```python
import torch
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# distribute_tensor now works with the `xla` backend using PyTorch/XLA SPMD.
xr.use_spmd()
world_size = xr.global_runtime_device_count()

mesh = DeviceMesh("xla", list(range(world_size)))
big_tensor = torch.randn(100000, 88)
my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)])
```
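To make `Shard(0)` concrete: sharding along dimension 0 splits the tensor's rows into one contiguous block per device in the mesh. The sketch below computes those row ranges in plain Python. It assumes torch.chunk-style splitting, where the chunk size is rounded up so trailing devices may get fewer rows or none. It is illustrative only, does not call DTensor, and the helper name is hypothetical.

```python
from typing import List, Tuple


def shard_row_ranges(num_rows: int, world_size: int) -> List[Tuple[int, int]]:
    """Return the [start, stop) row range each device owns under Shard(0).

    Assumes torch.chunk-style splitting: every shard has ceil(num_rows / world_size)
    rows except possibly the trailing ones, which may be shorter or empty.
    """
    if world_size <= 0:
        raise ValueError("world_size must be positive")
    chunk = -(-num_rows // world_size)  # ceiling division
    ranges = []
    for rank in range(world_size):
        start = min(rank * chunk, num_rows)
        stop = min(start + chunk, num_rows)
        ranges.append((start, stop))
    return ranges


if __name__ == "__main__":
    # The 100000 x 88 tensor from the snippet above on a hypothetical 8-device mesh.
    for rank, (start, stop) in enumerate(shard_row_ranges(100000, 8)):
        print(f"device {rank}: rows [{start}, {stop}) -> shape ({stop - start}, 88)")
```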


FSDP via SPMD

The naming can be a bit confusing, so let me explain. PyTorch/XLA already has an FSDP wrapper (API-wise it is very similar to FSDP in upstream PyTorch), and that wrapper will continue to be available. The idea behind FSDP v2 is simple. FSDP is a very user-friendly way to shard large-scale models, and it turns out to be very compute-efficient in many cases, so we implemented a very similar wrapper using SPMD. Compared to the existing FSDP wrapper, this approach compiles graphs faster, since SPMD only needs to compile one graph per device. It also has more efficient collectives, because the current FSDP cannot use asynchronous all-gather on TPUs. You can find more details in this doc.
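The core idea behind FSDP, in either wrapper, is that each device persistently stores only 1/N of every parameter and all-gathers the full parameter just before it is needed. The plain-Python sketch below simulates that with a flat list standing in for a parameter. The function names are hypothetical, and nothing here is the FSDP v2 API or how SPMD implements it.

```python
from typing import List


def shard_flat_param(flat: List[float], num_devices: int) -> List[List[float]]:
    """Split a flattened parameter into equal shards, zero-padding the tail."""
    shard_size = -(-len(flat) // num_devices)  # ceiling division
    padded = flat + [0.0] * (shard_size * num_devices - len(flat))
    return [padded[i * shard_size:(i + 1) * shard_size] for i in range(num_devices)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Simulate an all-gather: concatenate every device's shard, drop the padding."""
    gathered = [x for shard in shards for x in shard]
    return gathered[:numel]


def per_device_bytes(num_params: int, num_devices: int, bytes_per_param: int = 4) -> int:
    """Persistent parameter memory per device when parameters are fully sharded."""
    return -(-num_params // num_devices) * bytes_per_param


if __name__ == "__main__":
    param = [float(i) for i in range(10)]
    shards = shard_flat_param(param, 4)  # each simulated device holds 3 values
    print(shards)
    # Before a layer runs, the full parameter is rebuilt from all shards.
    print(all_gather(shards, len(param)))
    # For example, 7B fp32 parameters over 8 devices: 3.5 GB per device instead of 28 GB.
    print(per_device_bytes(7_000_000_000, 8))
```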

Sharding Device Visualization

We also added two APIs to visualize how a tensor is sharded across devices, given either a sharded tensor or a sharding string. They can be very handy when you want to figure out the most efficient device placement. You can find more details in this doc.
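A sharding string is XLA's textual sharding annotation. In the simple tiled case, such as `{devices=[2,2]0,1,2,3}`, it gives the tile grid shape followed by the device assigned to each tile in row-major order. The hypothetical helper below decodes that case into a device grid and renders it as text, which is roughly the core of such a visualizer. It is not the PyTorch/XLA API, and it deliberately rejects replicated, partially replicated, and iota-style shardings.

```python
import re
from typing import List

_TILED = re.compile(r"^\{devices=\[(\d+),(\d+)\]([\d,]+)\}$")


def parse_2d_tiled_sharding(sharding: str) -> List[List[int]]:
    """Decode a 2D tiled HLO sharding string into a grid of device ids.

    Only handles the explicit form '{devices=[R,C]d0,d1,...}'; other forms
    (for example '{replicated}' or iota-style '<=[N]') raise ValueError.
    """
    match = _TILED.match(sharding.strip())
    if match is None:
        raise ValueError(f"unsupported sharding string: {sharding!r}")
    rows, cols = int(match.group(1)), int(match.group(2))
    devices = [int(d) for d in match.group(3).split(",")]
    if len(devices) != rows * cols:
        raise ValueError("device count does not match the tile grid")
    return [devices[r * cols:(r + 1) * cols] for r in range(rows)]


def render_grid(grid: List[List[int]]) -> str:
    """Render each tile as 'dev N' in a fixed-width text table."""
    width = max(len(f"dev {d}") for row in grid for d in row)
    return "\n".join(" | ".join(f"dev {d}".ljust(width) for d in row) for row in grid)


if __name__ == "__main__":
    print(render_grid(parse_2d_tiled_sharding("{devices=[2,2]0,1,2,3}")))
```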

(WIP) Auto Sharding

We have been working on bringing XLA’s auto-sharding to PyTorch/XLA. This feature will allow XLA to pick sharding annotations without any user sharding hints. The XLA auto-sharding service is based on a published research work, Alpa. Our preliminary benchmarks on TPU look very promising; for more details please take a look at this RFC. The current plan is to release this feature as experimental in the 2.3 release.


Persistent Compilation Cache

Long compilation time is one of the most noticeable usability issues in PyTorch/XLA. SPMD partially addressed the issue by reducing the number of graphs, but that was not enough. The persistent compilation cache serializes each compiled program and saves it to local disk, so that it can be loaded the next time the model runs instead of being recompiled. Note that if you change the model code in a way that produces a different HLO graph, that graph still needs to be recompiled. For more details, please take a look at this doc.
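Conceptually, the cache is keyed by the computation itself. The program hashes the HLO, looks for a serialized executable under that key on disk, and compiles only on a miss, which is why a model change that alters the HLO forces a recompile. In PyTorch/XLA the cache is turned on with `torch_xla.runtime.initialize_cache(path, readonly=False)`, called before any computation runs. The sketch below models only the mechanism in plain Python. The class name, the `compile_fn` stand-in, and the file layout are hypothetical, and this is not PyTorch/XLA's implementation.

```python
import hashlib
import os
import tempfile
from typing import Callable


class ToyPersistentCache:
    """Disk cache mapping hash(HLO text) -> serialized executable bytes.

    The HLO text and compile_fn are stand-ins; a real cache would also need
    to key on things such as the compiler version and device type.
    """

    def __init__(self, cache_dir: str, compile_fn: Callable[[str], bytes]):
        self.cache_dir = cache_dir
        self.compile_fn = compile_fn
        self.compilations = 0
        os.makedirs(cache_dir, exist_ok=True)

    def _path(self, hlo_text: str) -> str:
        key = hashlib.sha256(hlo_text.encode()).hexdigest()
        return os.path.join(self.cache_dir, key + ".bin")

    def get_executable(self, hlo_text: str) -> bytes:
        path = self._path(hlo_text)
        if os.path.exists(path):  # hit: reuse the program saved by an earlier run
            with open(path, "rb") as f:
                return f.read()
        self.compilations += 1  # miss: compile, then persist for next time
        executable = self.compile_fn(hlo_text)
        with open(path, "wb") as f:
            f.write(executable)
        return executable


if __name__ == "__main__":
    def fake_compile(hlo: str) -> bytes:
        return b"compiled:" + hlo.encode()

    with tempfile.TemporaryDirectory() as d:
        ToyPersistentCache(d, fake_compile).get_executable("add(x, y)")  # run 1 compiles
        second_run = ToyPersistentCache(d, fake_compile)  # simulates a fresh process
        second_run.get_executable("add(x, y)")  # same HLO: loaded from disk
        second_run.get_executable("mul(x, y)")  # changed HLO: must recompile
        print(second_run.compilations)  # 1
```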

Compilation/Execution analysis

One common complaint about PyTorch/XLA is that users don’t understand what is actually happening at the PyTorch/XLA level. In this release we enhanced the PT_XLA_DEBUG=1 flag so that it dumps an analysis for every compilation and execution. Sample output looks like this:

```
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:   user mark_step
Execution Analysis: Graph Info:
Execution Analysis:   Graph Hash: 537d4b0264b029688281412214d252e9
Execution Analysis:   Number of Graph Inputs: 588
Execution Analysis:   Number of Graph Outputs: 320
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis:   mark_step (/workspaces/dk2/pytorch/xla/torch_xla/core/
Execution Analysis:   broadcast_master_param (/workspaces/dk2/pytorch/xla/torch_xla/core/
```
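Every line of the analysis shares the `Execution Analysis:` prefix, so the output is easy to post-process. As one hypothetical example, the helper below scans captured debug output for the `Graph Hash:` field and counts how often each graph executed. If steps that should be identical produce many distinct hashes, that can point to recompilation. The helper relies only on the line format in the sample above, which may change between releases.

```python
import re
from collections import Counter
from typing import Iterable

_HASH_LINE = re.compile(r"Execution Analysis:\s+Graph Hash:\s+([0-9a-f]+)")


def count_executed_graphs(log_lines: Iterable[str]) -> Counter:
    """Count executions per graph hash in PT_XLA_DEBUG=1 output."""
    counts: Counter = Counter()
    for line in log_lines:
        match = _HASH_LINE.search(line)
        if match:
            counts[match.group(1)] += 1
    return counts


if __name__ == "__main__":
    sample = [
        "Execution Analysis: Execution Cause",
        "Execution Analysis:   user mark_step",
        "Execution Analysis:   Graph Hash: 537d4b0264b029688281412214d252e9",
    ]
    print(count_executed_graphs(sample))
```

In practice you would feed it the lines of a log captured from a training run instead of the inline sample.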

Async profile captures

Previously, users had to manually run a capture script in a separate terminal while the model was running, which does not scale well. We added trace_detached to provide a way to capture a profile automatically from within the training script.
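The pattern behind a detached capture is simple. The training script itself starts the capture on a background thread, optionally after a delay to skip warm-up and compilation, so the training loop never blocks on it. The standard-library sketch below shows only that pattern. `start_detached_capture` and `capture_fn` are hypothetical stand-ins for whatever actually collects the trace, and the real trace_detached arguments are not reproduced here.

```python
import threading
import time
from typing import Callable


def start_detached_capture(capture_fn: Callable[[int], None],
                           duration_ms: int,
                           delay_ms: int = 0) -> threading.Thread:
    """Run capture_fn(duration_ms) on a daemon thread after an optional delay.

    The caller's training loop keeps running while the capture happens.
    """
    def _delayed_capture() -> None:
        time.sleep(delay_ms / 1000)  # e.g. skip warm-up / compilation steps
        capture_fn(duration_ms)

    thread = threading.Thread(target=_delayed_capture, daemon=True)
    thread.start()
    return thread


if __name__ == "__main__":
    def fake_capture(duration_ms: int) -> None:
        print(f"capturing {duration_ms} ms of trace")
        time.sleep(duration_ms / 1000)

    capture = start_detached_capture(fake_capture, duration_ms=50, delay_ms=20)
    for step in range(5):  # training continues while the capture runs
        time.sleep(0.01)
    capture.join()
```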


(WIP) Single Step Graph

When tracing a model, Dynamo currently captures the forward pass as a single graph and uses AOTAutograd to generate a separate backward graph. Splitting one training step across multiple graphs has performance and memory drawbacks compared to a single graph. We worked with the PyTorch team on a proposal to use Compiled Autograd to generate the backward graph and have Dynamo inline it. The RFC can be found here and the POC here. The current plan is to break the POC down and merge it gradually. We expect some of the code changes to land in the 2.3 release, but the whole feature will likely take longer.


Multi-host Training

As of the 2.2 release, we officially support multi-node training on GPU through the PJRT runtime and torchrun.


The focus of the SPMD work has been on TPU, but we want to give GPU SPMD more attention in the 2.3 release. The plan is to unblock the functionality in 2.3 and do more detailed performance analysis and optimization in later releases.

(WIP) Benchmark

We are working on setting up daily TorchBench runs on GPU (and TPU) to compare the performance of XLA:GPU and PyTorch/XLA:GPU against other compiler backends such as TorchInductor. Along the way we fixed a number of issues in our benchmarking scripts and models. The plan is to automate the benchmark, collect data, and publish it on a public dashboard by the 2.3 release.

Core Aten Ops

As of the 2.2 release, PyTorch/XLA supports ~80% of the core ATen ops. The goal is to bring this close to 100% (except for a couple of dynamic ops) by the 2.3 release. This will improve performance when training or running inference with PyTorch/XLA models and significantly increase model coverage for torch.export.


Distributed Checkpointing

PyTorch/XLA supports distributed checkpointing through torch.distributed.checkpoint. To enable more advanced checkpointing features, we introduced the CheckpointManager class. It enables per-step checkpoint management, asynchronous checkpointing, and autocheckpointing on preemption. See the docs for a more complete description.
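Per-step checkpoint management comes down to two rules. Save when the step hits the save interval, and delete the oldest checkpoints beyond a retention limit. The toy manager below applies those rules to JSON files in a local directory. The class, its parameters, and the file layout are hypothetical, and it is not CheckpointManager's implementation, which writes through torch.distributed.checkpoint.

```python
import json
import os
import tempfile
from typing import Any, Dict, List


class ToyCheckpointManager:
    """Save every `save_interval` steps and keep only the newest `max_to_keep`."""

    def __init__(self, root: str, save_interval: int, max_to_keep: int = 0):
        self.root = root
        self.save_interval = save_interval
        self.max_to_keep = max_to_keep  # 0 means keep every checkpoint
        os.makedirs(root, exist_ok=True)

    def all_steps(self) -> List[int]:
        """Steps that currently have a checkpoint on disk, oldest first."""
        return sorted(int(name.split(".")[0]) for name in os.listdir(self.root)
                      if name.endswith(".json"))

    def save(self, step: int, state: Dict[str, Any], force: bool = False) -> bool:
        """Write a checkpoint if the interval (or force) says so; prune old ones."""
        if not force and step % self.save_interval != 0:
            return False
        with open(os.path.join(self.root, f"{step}.json"), "w") as f:
            json.dump(state, f)
        if self.max_to_keep > 0:
            for old in self.all_steps()[:-self.max_to_keep]:
                os.remove(os.path.join(self.root, f"{old}.json"))
        return True

    def restore(self, step: int) -> Dict[str, Any]:
        with open(os.path.join(self.root, f"{step}.json")) as f:
            return json.load(f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        mgr = ToyCheckpointManager(d, save_interval=10, max_to_keep=2)
        for step in range(50):
            mgr.save(step, {"step": step, "loss": 1.0 / (step + 1)})
        print(mgr.all_steps())  # [30, 40]
```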

Async Checkpointing

One key feature of the CheckpointManager API is async checkpointing, which allows training to continue while the checkpoint is written to persistent storage in the background. It is available through the CheckpointManager.save_async method.
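The essential trick in asynchronous checkpointing is to take a cheap in-memory snapshot of the state on the training thread. The slow write then happens on a background thread, so later updates to the live state cannot leak into the checkpoint. A minimal standard-library sketch of that pattern follows. `save_async` here is a hypothetical free function, not CheckpointManager.save_async, and a deep copy stands in for copying device tensors to host memory.

```python
import copy
import json
import os
import tempfile
import threading
from typing import Any, Dict


def save_async(path: str, state: Dict[str, Any]) -> threading.Thread:
    """Snapshot `state` now, then write it to `path` on a background thread."""
    snapshot = copy.deepcopy(state)  # taken synchronously, before training moves on

    def _write() -> None:
        tmp_path = path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(snapshot, f)
        os.replace(tmp_path, path)  # rename at the end so readers never see a partial file

    thread = threading.Thread(target=_write)
    thread.start()
    return thread


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        state = {"step": 100, "weights": [0.1, 0.2]}
        pending = save_async(os.path.join(d, "ckpt_100.json"), state)
        state["weights"][0] = 99.0  # training keeps mutating the live state
        pending.join()  # e.g. wait before exiting or before the next save
        with open(os.path.join(d, "ckpt_100.json")) as f:
            print(json.load(f))  # still has weights [0.1, 0.2]
```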

Auto Checkpointing on Preemption

When autocheckpointing is enabled on a Cloud TPU, the CheckpointManager class takes a checkpoint when a preemption is detected. The checkpoint is identified by the step at which the preemption was detected.
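Auto-checkpointing on preemption has a simple shape. Something outside the training loop notices the preemption and sets a flag. The loop checks that flag every step, and when it is set, it saves a checkpoint tagged with the current step and stops. The sketch below uses SIGTERM as a stand-in preemption notice, and `PreemptionWatcher`, `train`, `step_fn`, and `save_fn` are hypothetical. How PyTorch/XLA actually detects a Cloud TPU preemption is not modeled here.

```python
import signal
import threading
from typing import Any, Callable, Optional


class PreemptionWatcher:
    """Records that a preemption notice arrived; the training loop polls it."""

    def __init__(self) -> None:
        self._event = threading.Event()

    def notify(self, *_: Any) -> None:
        # Accepts and ignores (signum, frame) so it can double as a signal handler.
        self._event.set()

    def install_sigterm_handler(self) -> None:
        # Assumption: SIGTERM stands in for the real preemption notice.
        # signal.signal() may only be called from the main thread.
        signal.signal(signal.SIGTERM, self.notify)

    def preempted(self) -> bool:
        return self._event.is_set()


def train(num_steps: int, watcher: PreemptionWatcher,
          step_fn: Callable[[int], None],
          save_fn: Callable[[int], None]) -> Optional[int]:
    """Run step_fn per step; on preemption, checkpoint at that step and stop."""
    for step in range(num_steps):
        step_fn(step)  # one (stand-in) training step
        if watcher.preempted():
            save_fn(step)  # checkpoint identified by the step that saw the notice
            return step
    return None  # finished without being preempted


if __name__ == "__main__":
    watcher = PreemptionWatcher()
    watcher.install_sigterm_handler()

    def step_fn(step: int) -> None:
        if step == 3:  # pretend the preemption notice arrives during step 3
            signal.raise_signal(signal.SIGTERM)

    train(100, watcher, step_fn, lambda step: print(f"checkpoint saved at step {step}"))
```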