Hey, I am here to give a late update on the PyTorch/XLA 2.3 release. As with my previous update, you can check our release notes for the detailed changes; here I am going to highlight some of the new features and share how I think about them.
SPMD Auto-sharding
We launched experimental support for single-host TPU auto-sharding in the 2.3 release. It can be enabled via torch_xla's custom API:
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
Or you can enable it with the distributed module API:
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
# Currently, the model should be loaded to the XLA device via distribute_module.
model = MyModule()  # an nn.Module
sharded_model = distribute_module(model, device_mesh, auto_policy)
You can find more details about auto-sharding here. IMO, auto-sharding is a great way to experience the capabilities of GSPMD. We benchmarked popular transformer models like GPT2, Llama2, and Gemma; auto-sharding achieves ~95% of the MFU of carefully optimized manual sharding on Cloud TPU v4-8. Auto-sharding does incur a higher compile time, hence we strongly recommend users turn on the persistent compilation cache by following this guide.
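As a rough illustration, here is a minimal sketch of turning on the persistent cache (assuming the xr.initialize_cache API described in that guide; the cache directory is an arbitrary example):
import torch_xla.runtime as xr

# Write compiled programs to disk so subsequent runs can reuse them
# instead of recompiling. The directory is just an example path.
xr.initialize_cache('/tmp/xla_compile_cache', readonly=False)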
My 2 cents on auto-sharding: it is a cool feature that gives you a performance baseline and can inspire how to shard the model. However, the long compilation time makes it a bit challenging to use if you are constantly changing the model code. For most transformer-like models, my recommendation is to first try FSDPv2 (which I cover in the next section) and see if it works.
FSDPv2 (FSDP via SPMD)
We are still trying to figure out whether we can come up with a fancier name, but for now we have settled on FSDPv2. I already talked about this in my last update. FSDPv2 has become our standard way of training large language models. We recently integrated FSDPv2 with Hugging Face and have achieved 60%+ MFU on Llama 3 and Gemma on Cloud TPU v5p. I think this should be the first thing users try when they want to scale up their model.
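For reference, here is a minimal sketch of wrapping a model with FSDPv2 (assuming the SpmdFullyShardedDataParallel wrapper and the global mesh helpers in torch_xla's SPMD module; MyModule is a placeholder):
import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# A 1D 'fsdp' axis plus a trivial 'model' axis; FSDPv2 shards along 'fsdp'.
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('fsdp', 'model'))
xs.set_global_mesh(mesh)

model = FSDPv2(MyModule().to(xm.xla_device()))  # MyModule is a placeholder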
Pallas
Pallas is a custom kernel language that supports both TPU and GPU. In the 2.3 release PyTorch/XLA integrated with Pallas on TPU and added support for using a Pallas-based FlashAttention kernel in the model forward pass. We also added torch.compile support for the flash attention kernel:
from torch_xla.experimental.custom_kernel import flash_attention
output = flash_attention(q, k, v)
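The torch.compile path looks roughly like this (a sketch, assuming the openxla dynamo backend; the attention_fn wrapper and tensor shapes are illustrative):
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

def attention_fn(q, k, v):
    # Calls the Pallas FlashAttention kernel inside the compiled region.
    return flash_attention(q, k, v)

compiled_attention = torch.compile(attention_fn, backend='openxla')

device = xm.xla_device()
# [batch, num_heads, seq_len, head_dim]; shapes are illustrative.
q = torch.randn(4, 8, 1024, 128, device=device)
k = torch.randn(4, 8, 1024, 128, device=device)
v = torch.randn(4, 8, 1024, 128, device=device)
output = compiled_attention(q, k, v)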
In the nightly builds we also added support for the backward pass, along with a couple of other popular kernels like paged_attention and gmm. With the Pallas integration, PyTorch/XLA can quickly experiment with the most advanced kernels. We will keep adding popular Pallas kernels and examples to the repo so users can use them directly. For more details please take a look at this user guide.
Export
Starting from the 2.3 release, PyTorch/XLA considers export a stable feature. In this release we added support for all of the core ATen ops used by export, enabled dynamism in export, and fixed a bunch of bugs. Please give this feature a try if you are interested in exporting a PyTorch model to StableHLO.
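A minimal sketch of the export-to-StableHLO flow (assuming torch.export.export and torch_xla.stablehlo.exported_program_to_stablehlo; MyModule and the sample input are placeholders):
import torch
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

model = MyModule().eval()          # MyModule is a placeholder nn.Module
sample_args = (torch.randn(4, 3, 224, 224),)

exported = export(model, sample_args)
stablehlo_program = exported_program_to_stablehlo(exported)
# Inspect the generated StableHLO text for the forward method.
print(stablehlo_program.get_stablehlo_text('forward'))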
Torch Dynamo
In the 2.3 release PyTorch/XLA added the custom op dynamo_mark_sharding, which can be used to perform activation sharding inside a torch.compile region. This is the first step toward making torch.compile + GSPMD the recommended way of doing model inference with PyTorch/XLA. We plan to provide more examples and performance benchmarks in the 2.4 release.
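To make "activation sharding" concrete, here is what the eager SPMD equivalent looks like with mark_sharding (a sketch; dynamo_mark_sharding is the torch.compile-friendly counterpart of this call, and its exact argument layout is best taken from the release notes):
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

# Shard an activation tensor along the 'data' axis of the mesh.
activation = model(inputs)          # model and inputs are placeholders
xs.mark_sharding(activation, mesh, ('data', None))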
GPU
PyTorch/XLA added official support for running SPMD on a single GPU node or across multiple GPU nodes. You can find the instructions in this doc. We were able to get ~50% MFU on Llama2 using a single GPU node with 8 A100 GPUs. PyTorch/XLA:GPU is also now a plugin that you can install with
pip install torch~=2.3.0 torch_xla~=2.3.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.3.0-py3-none-any.whl
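Once the plugin is installed, you select the CUDA runtime through the PJRT environment variables; a minimal sketch (you can equivalently export these in the shell before launching your script):
import os

# Must be set before torch_xla initializes its runtime.
os.environ.setdefault('PJRT_DEVICE', 'CUDA')
os.environ.setdefault('GPU_NUM_DEVICES', '8')  # number of local GPUs

import torch_xla.core.xla_model as xm
print(xm.xla_device())  # an XLA device backed by CUDA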
While loop
A frequent criticism of PyTorch/XLA is its long compilation time. One of the root causes is that every Python loop gets unrolled, producing a huge HLO. In the 2.3 release PyTorch/XLA added experimental support for torch._higher_order_ops.while_loop. Currently it can only handle very basic torch operations, but we aim to add nn module support in an upcoming release. The goal is to have while_loop wrap a module (like a DecoderLayer in Llama) and produce a compact HLO. For more details please take a look at this doc.
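A minimal sketch of the new primitive, assuming the torch._higher_order_ops.while_loop entry point and tensor-valued loop carries (the supported operation set is still very limited, so treat this as illustrative):
import torch
import torch_xla.core.xla_model as xm
from torch._higher_order_ops.while_loop import while_loop

device = xm.xla_device()

def cond_fn(i, x):
    # Keep looping while the counter is below 10.
    return i < 10

def body_fn(i, x):
    # Carried values must keep the same shapes/dtypes across iterations.
    return i + 1, x + 1.0

i = torch.tensor(0, device=device)
x = torch.zeros(4, device=device)
i_final, x_final = while_loop(cond_fn, body_fn, (i, x))
print(x_final)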