PyTorch/XLA 2.3 dev update

Hey, I am here to give a late update on the PyTorch/XLA 2.3 release. As with my previous update, you can check our release notes for the detailed changes. I am going to highlight some of the new features and share how I think about them.

SPMD Auto-sharding

The 2.3 release adds experimental support for auto-sharding on single-host TPUs. You can enable it through PyTorch/XLA's own API:

import torch_xla.runtime as xr

# Enable SPMD mode with auto-sharding.
xr.use_spmd(auto=True)

Alternatively, you can enable it through the PyTorch distributed module API:

import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# Currently, the model must be loaded onto the XLA device via distribute_module.
model = MyModule()  # MyModule is a placeholder for your own nn.Module
sharded_model = distribute_module(model, device_mesh, auto_policy)

You can find more details about auto-sharding here. IMO, auto-sharding is a great way to experience the capabilities of GSPMD. We benchmarked popular transformer models such as GPT2, Llama2, and Gemma, and auto-sharding reached ~95% of the MFU achieved by carefully optimized alternatives on Cloud TPU v4-8. Auto-sharding does incur a longer compile time, so we strongly recommend that users turn on the persistent compilation cache by following this guide.
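To make MFU (model FLOPs utilization) numbers like the one above concrete, here is a minimal sketch of how MFU is commonly estimated for transformer training. It assumes the widely used approximation of ~6 FLOPs per parameter per trained token and ignores attention FLOPs; the function names and the throughput numbers in the example are hypothetical, not the benchmark setup or results described above.

```python
def training_flops_per_token(num_params: float) -> float:
    """Approximate model FLOPs needed to train on one token.

    Assumption: ~2 FLOPs/param for the forward pass plus ~4 for the backward
    pass (the common "6N" rule of thumb); attention FLOPs are ignored.
    """
    return 6.0 * num_params


def mfu(num_params: float, global_tokens_per_sec: float,
        num_chips: int, peak_flops_per_chip: float) -> float:
    """Model FLOPs utilization: achieved model FLOP/s divided by peak FLOP/s."""
    achieved = training_flops_per_token(num_params) * global_tokens_per_sec
    peak = num_chips * peak_flops_per_chip
    return achieved / peak


if __name__ == "__main__":
    # Hypothetical example: a 7B-parameter model training at 4,000 tokens/s
    # in total on 4 chips, assuming a 275e12 FLOP/s bf16 peak per chip (TPU v4).
    print(f"MFU = {mfu(7e9, 4_000, 4, 275e12):.1%}")  # MFU = 15.3%
```

Because MFU counts only the FLOPs the model mathematically requires, recomputation such as activation rematerialization does not increase it.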

My 2 cents on auto-sharding: it is a cool feature that gives us a performance baseline and can inspire how we shard the model. However, the long compilation time makes it a bit challenging to use if you are constantly changing the model code. For most transformer-like models, my recommendation is to first try FSDPv2 (I am going to talk about it in the next section) and see if it works.


FSDPv2

We are still trying to figure out whether we can come up with a fancier name, but for now we have settled on FSDPv2. I already talked about it in my last update. FSDPv2 has become our standard way of training large language models. We recently integrated FSDPv2 with Hugging Face and have achieved 60%+ MFU on Llama 3 and Gemma on Cloud TPU v5p. I think this should be the first thing users try when they want to scale up their model.
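As a back-of-the-envelope illustration of why fully sharded data parallelism helps when scaling up, here is a sketch of per-device memory for model state. It assumes FSDP/ZeRO-3-style sharding in which parameters, gradients, and Adam optimizer state are all split evenly across devices, with fp32 storage (4 + 4 + 8 = 16 bytes per parameter); activations are ignored, and the helper name and model size are hypothetical.

```python
# Assumed bytes per parameter: fp32 weights (4) + fp32 grads (4)
# + Adam first/second moments in fp32 (4 + 4).
BYTES_PER_PARAM = 16


def per_device_state_bytes(num_params: int, num_devices: int,
                           bytes_per_param: int = BYTES_PER_PARAM) -> int:
    """Model-state bytes held by each device when state is evenly sharded.

    num_devices=1 corresponds to plain (replicated) data parallelism, where
    every device keeps a full copy. Uneven shards are rounded up (padding).
    """
    shard_params = (num_params + num_devices - 1) // num_devices  # ceil division
    return shard_params * bytes_per_param


if __name__ == "__main__":
    params = 7_000_000_000  # hypothetical 7B-parameter model
    for devices in (1, 8, 64):
        gb = per_device_state_bytes(params, devices) / 1e9
        print(f"{devices:>3} devices: {gb:.1f} GB of model state per device")
```

Activations, temporarily all-gathered weights, and compiler buffers come on top of this, so real per-device usage is higher.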


Pallas

Pallas is a custom kernel language that supports both TPU and GPU. In the 2.3 release, PyTorch/XLA integrated with Pallas on TPU and added support for a Pallas-based FlashAttention kernel in the model forward pass. We also added support for using this flash attention kernel with torch.compile.

from torch_xla.experimental.custom_kernel import flash_attention

# q, k, v: XLA tensors laid out as [batch, num_heads, seq_len, head_dim]
output = flash_attention(q, k, v)
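The core idea behind FlashAttention-style kernels is to walk over keys and values in blocks while keeping a running maximum and a running softmax denominator, so the full attention matrix is never materialized. Below is a plain-Python sketch of that online-softmax recurrence for a single query vector, checked against naive attention. It is only a conceptual illustration, not the Pallas kernel or the torch_xla implementation, and the function names and the `block_size` and `scale` parameters are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def _dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q: Vector, keys: List[Vector], values: List[Vector],
                    scale: float = 1.0) -> Vector:
    """softmax(scale * q . K^T) @ V, computed over all keys at once."""
    scores = [scale * _dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def blockwise_attention(q: Vector, keys: List[Vector], values: List[Vector],
                        block_size: int = 2, scale: float = 1.0) -> Vector:
    """Same result, but visits keys/values one block at a time (online softmax)."""
    running_max = -math.inf          # max score seen so far
    running_sum = 0.0                # sum of exp(score - running_max)
    acc = [0.0] * len(values[0])     # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * _dot(q, k) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new max (exp(-inf) == 0.0).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


if __name__ == "__main__":
    q = [0.5, -1.0]
    keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.3]]
    values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
    print(naive_attention(q, keys, values))
    print(blockwise_attention(q, keys, values, block_size=2))
```

The rescaling step is what lets the kernel keep only one block of scores in fast memory at a time while still producing the exact softmax result.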

In the nightly builds, we also added support for the backward pass, along with a couple of other popular kernels such as paged_attention and gmm. With the Pallas integration, PyTorch/XLA can quickly experiment with the most advanced features. We will keep adding popular Pallas kernels and examples to the repo so users can use them directly. For more details, please take a look at this user guide.
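For readers new to gmm: a grouped matrix multiplication, as used in mixture-of-experts layers, multiplies each contiguous group of rows of `lhs` by that group's own weight matrix. The sketch below spells out those reference semantics in plain Python, assuming the common grouped-matmul layout (`lhs` of shape [m, k], `rhs` of shape [num_groups, k, n], `group_sizes` summing to m). Please check the kernel's docstring for its exact contract, and treat the function name as hypothetical.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def grouped_matmul_reference(lhs: Matrix, rhs: Sequence[Matrix],
                             group_sizes: Sequence[int]) -> Matrix:
    """Rows [offset, offset + group_sizes[g]) of lhs are multiplied by rhs[g]."""
    if sum(group_sizes) != len(lhs):
        raise ValueError("group_sizes must sum to the number of lhs rows")
    out: Matrix = []
    offset = 0
    for g, size in enumerate(group_sizes):
        weight = rhs[g]  # k x n matrix for this group (e.g. one expert)
        n = len(weight[0])
        for row in lhs[offset:offset + size]:
            out.append([sum(row[i] * weight[i][j] for i in range(len(row)))
                        for j in range(n)])
        offset += size
    return out


if __name__ == "__main__":
    tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 tokens, k = 2
    experts = [[[1.0, 0.0], [0.0, 1.0]],           # expert 0: identity
               [[2.0, 0.0], [0.0, 2.0]]]           # expert 1: doubles
    # Token 0 is routed to expert 0, tokens 1-2 to expert 1.
    print(grouped_matmul_reference(tokens, experts, [1, 2]))
    # [[1.0, 2.0], [6.0, 8.0], [10.0, 12.0]]
```

In an MoE layer, tokens are first sorted by their assigned expert so that each expert's tokens form one contiguous group.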


Export

Starting with the 2.3 release, PyTorch/XLA considers export a stable feature. In this release we added support for all of the core ATen ops used by export. We also enabled dynamism in export and fixed a bunch of bugs. Please give this feature a try if you are interested in exporting a PyTorch model to StableHLO.

Torch Dynamo

In the 2.3 release, PyTorch/XLA added the custom op dynamo_mark_sharding, which can be used to shard activations inside a torch.compile region. This is the first step toward making torch.compile + GSPMD the recommended way to do model inference with PyTorch/XLA. We plan to provide more examples and performance benchmarks in the 2.4 release.
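To make activation sharding concrete: in PyTorch/XLA's SPMD API, a partition spec assigns each tensor dimension to a mesh axis, or to None to replicate it. The plain-Python sketch below computes the per-device shard shape implied by such a spec. It is a simplified model of GSPMD tiling (one mesh axis per dimension, uneven shards assumed to be padded up), and the function name and example shapes are hypothetical.

```python
from typing import Optional, Sequence, Tuple, Union

AxisRef = Optional[Union[int, str]]


def shard_shape(tensor_shape: Sequence[int], mesh_shape: Sequence[int],
                axis_names: Sequence[str],
                partition_spec: Sequence[AxisRef]) -> Tuple[int, ...]:
    """Per-device shard shape for a tensor tiled over a logical device mesh.

    partition_spec[i] names (or indexes) the mesh axis that tensor dimension i
    is split across; None means dimension i is replicated on every device.
    """
    if len(partition_spec) != len(tensor_shape):
        raise ValueError("partition_spec needs one entry per tensor dimension")
    shape = []
    for dim, spec in zip(tensor_shape, partition_spec):
        if spec is None:
            shape.append(dim)  # replicated: every device holds the full dim
            continue
        axis = axis_names.index(spec) if isinstance(spec, str) else spec
        parts = mesh_shape[axis]
        shape.append(-(-dim // parts))  # ceil division; uneven tiles padded
    return tuple(shape)


if __name__ == "__main__":
    # Hypothetical activation [batch, seq, hidden] on an 8-device (4, 2) mesh.
    mesh_shape, axis_names = (4, 2), ("data", "model")
    print(shard_shape((16, 2048, 4096), mesh_shape, axis_names,
                      ("data", None, "model")))  # (4, 2048, 2048)
```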


GPU

PyTorch/XLA added official support for running SPMD on single-node or multi-node GPU setups. You can find the instructions in this doc. We were able to get ~50% MFU on Llama2 using a single GPU node with 8 A100 GPUs. PyTorch/XLA:GPU is also now a plugin that you can install with

pip install torch~=2.3.0 torch_xla~=2.3.0

While loop

A frequent criticism of PyTorch/XLA is its long compilation time. One root cause is that all Python loops get unrolled, which produces a very large HLO. In the 2.3 release, PyTorch/XLA added experimental support for torch._higher_order_ops.while_loop. Currently it can only handle very basic torch operations, but we aim to add nn.Module support in the upcoming release. The goal is to have while_loop wrap around a module (like a DecoderLayer in Llama) and produce a compact HLO. For more details, please take a look at this doc.
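The compile-time argument above can be illustrated with a toy tracer. In the plain-Python sketch below (a deliberately simplified model, not PyTorch/XLA's lazy tensor tracer; all names are hypothetical), unrolling a Python loop records the loop body once per iteration, so the graph grows linearly with the trip count, while a while-loop node records the body once and keeps the trip count as data.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Graph:
    """A stand-in for an HLO module: just an ordered list of op names."""
    ops: List[str] = field(default_factory=list)

    def emit(self, op: str) -> None:
        self.ops.append(op)


def trace_body(graph: Graph, prefix: str = "") -> None:
    # One iteration of a toy layer: a matmul followed by a bias add.
    graph.emit(prefix + "dot")
    graph.emit(prefix + "add")


def trace_unrolled(num_iters: int) -> Graph:
    """What tracing a Python for-loop produces: the body copied num_iters times."""
    graph = Graph()
    for _ in range(num_iters):
        trace_body(graph)
    return graph


def trace_while_loop(num_iters: int) -> Graph:
    """A while-loop op: condition and body traced once; trip count is a constant."""
    graph = Graph()
    graph.emit(f"constant({num_iters})")
    graph.emit("while")
    graph.emit("cond/compare")
    trace_body(graph, prefix="body/")
    return graph


if __name__ == "__main__":
    for n in (1, 8, 32):
        # Columns: iterations, unrolled graph size, while-loop graph size.
        print(n, len(trace_unrolled(n).ops), len(trace_while_loop(n).ops))
```

Since compile time tends to grow with graph size, keeping the loop body as a single subgraph is what makes the while_loop approach attractive for deep stacks of identical layers.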