PyTorch/XLA 2.1 Release Dev Update

It has been a while since I wrote the last dev update for PyTorch/XLA. The original plan was to give this kind of update quarterly, but my long vacation earlier this year disrupted the schedule. A couple of folks have asked me about the PyTorch/XLA roadmap, so going forward I will try to write this dev update at least once per release.

The content of this document overlaps with our release notes. The purpose of this doc is to explain why we are making certain technical decisions and to share our plan for the next few months.




Owner: Yeounoh Chung, Jon Bolin, Jiewen Tan, Jack Cao


Owner: Yeounoh Chung

We introduced the mark_sharding API, which assigns a sharding annotation to a tensor. The idea is that users only have to call this API on key tensors and let the compiler handle the actual communication across accelerators. More technical details can be found in this blog post.
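To make the idea of an annotation concrete, here is a plain-Python model of which slice of a tensor each device would hold for a given logical device mesh and partition spec. It is a conceptual sketch only, assuming a partition spec that maps each tensor dimension to a mesh axis (or None for replication) and ceil-sized shards. `shard_slices` is a hypothetical helper, not part of PyTorch/XLA; the real partitioning, including padding of uneven shards, is done by the XLA compiler.

```python
from itertools import product
from typing import Dict, List, Optional, Tuple


def shard_slices(
    shape: Tuple[int, ...],
    mesh_shape: Tuple[int, ...],
    partition_spec: Tuple[Optional[int], ...],
) -> Dict[Tuple[int, ...], List[Tuple[int, int]]]:
    """Return, for every device coordinate in the mesh, the [start, stop) range it
    holds along each tensor dimension.

    partition_spec[d] names the mesh axis that tensor dim d is split over;
    None means dim d is replicated on every device.
    """
    if len(shape) != len(partition_spec):
        raise ValueError("partition_spec needs one entry per tensor dimension")
    shards = {}
    for coord in product(*(range(n) for n in mesh_shape)):
        ranges = []
        for dim_size, axis in zip(shape, partition_spec):
            if axis is None:
                ranges.append((0, dim_size))  # replicated: the full extent
                continue
            chunk = -(-dim_size // mesh_shape[axis])  # ceil division
            start = min(coord[axis] * chunk, dim_size)
            ranges.append((start, min(start + chunk, dim_size)))  # last shard may be short
        shards[coord] = ranges
    return shards


# An (8, 4) tensor on a 2 x 2 mesh: dim 0 split over mesh axis 0, dim 1 over mesh axis 1.
print(shard_slices((8, 4), (2, 2), (0, 1))[(1, 1)])  # [(4, 8), (2, 4)]
```

The user only states the layout; deciding which collectives are needed to move data between these slices is left to the compiler.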

Benchmark and Performance Optimization

Owner: Jiewen Tan, Jon Bolin, Yeounoh Chung

We used Llama 2 and GPT-2 to validate GSPMD performance on TPU. For Llama 2, we used this GitHub fork and achieved over 53% model FLOPs utilization (MFU) on the 7B, 13B, and 70B models. There will be a blog post with more details soon.
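For readers unfamiliar with MFU, the sketch below shows the common back-of-the-envelope estimate: achieved model FLOPs per second divided by the hardware's peak FLOPs per second. It assumes the widely used approximation of 6 × (parameter count) training FLOPs per token; the exact accounting behind the 53% figure is not specified here, and the numbers in the example call are made up for illustration, not measurements.

```python
def model_flops_utilization(
    n_params: float,
    tokens_per_second: float,
    n_chips: int,
    peak_flops_per_chip: float,
) -> float:
    """Estimate MFU for dense decoder-only training.

    Uses the common approximation of 6 * n_params FLOPs per trained token
    (about 2N for the forward pass and 4N for the backward pass), which ignores
    attention-score FLOPs.
    """
    achieved_flops = 6.0 * n_params * tokens_per_second
    peak_flops = n_chips * peak_flops_per_chip
    return achieved_flops / peak_flops


# Hypothetical numbers, for illustration only (not measured results).
mfu = model_flops_utilization(7e9, 10_000, n_chips=1, peak_flops_per_chip=1e15)
print(f"{mfu:.0%}")  # 42%
```

Because MFU counts only the FLOPs the model mathematically needs, recomputation (e.g. activation checkpointing) does not inflate it, which makes it a fair way to compare setups.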

Release 2.2 - 2.3 Feature Plan


Owner: Yeounoh Chung, Xiongfei Wei

Our early experiments were mostly done on TPUs. Our friends at Alibaba have been testing GSPMD on GPUs with promising results, and we plan to increase coverage and performance validation on GPUs as well.

Auto Sharding

Owner: Yeounoh Chung

OpenXLA has existing tooling that automatically adds sharding annotations to an HLO graph. We plan to test this feature and bring it to PyTorch/XLA.

Distributed Tensor Integration

Owner: Yeounoh Chung, Jiewen Tan

Our team has been working closely with Wanchao and Junjie on the Distributed Tensor integration, since both projects started at roughly the same time. Right now, Yeounoh Chung and Jiewen Tan are working on a POC that integrates with the Distributed Tensor API, so users can use a unified API for tensor parallelism. We believe this is a good integration point because the Distributed Tensor API is high-level enough that each backend can implement the actual sharding differently (I believe that for PyTorch native, the decomposition into collective ops happens at the framework level, while for XLA it happens at the compiler level).

There is one major difference between XLA GSPMD and PyTorch native DTensor: the number of training processes. PyTorch/XLA GSPMD expects to spawn only one process per host, which executes on all accessible devices, while native DTensor spawns one process per device. From XLA's perspective, using only one process per host reduces the number of compilations we need to perform per host, and XLA compilation can be expensive. We are aware that this divergence might cause some confusion, and we will discuss this topic further with Wanchao.

SPMD Debugging Tools

Owner: Manfei Bai, Yeounoh Chung

SPMD computations can be difficult to reason about, especially the sharding annotations and where each shard actually lives. We plan to add a debugging package to help debug SPMD workloads.
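As a rough idea of the kind of view such a package could offer, the hypothetical helper below prints, for a 2D tensor tiled across a 2D device mesh, which device owns each element. It assumes row-major device numbering and ceil-sized tiles and is not an existing PyTorch/XLA API; a real tool would read the shardings XLA actually chose rather than recomputing them.

```python
def render_tiling(rows: int, cols: int, mesh_rows: int, mesh_cols: int) -> str:
    """Show which device owns each element of a rows x cols tensor tiled over a
    mesh_rows x mesh_cols mesh (tensor dim 0 -> mesh axis 0, dim 1 -> mesh axis 1).
    Devices are numbered row-major over the mesh; tiles use ceil-sized chunks."""
    row_chunk = -(-rows // mesh_rows)
    col_chunk = -(-cols // mesh_cols)
    width = len(str(mesh_rows * mesh_cols - 1))  # pad ids so columns line up
    lines = []
    for r in range(rows):
        owners = [
            str((r // row_chunk) * mesh_cols + c // col_chunk).rjust(width)
            for c in range(cols)
        ]
        lines.append(" ".join(owners))
    return "\n".join(lines)


print(render_tiling(4, 4, 2, 2))
# 0 0 1 1
# 0 0 1 1
# 2 2 3 3
# 2 2 3 3
```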



LLM Toolchain

Owner: Jiewen Tan, Jon Bolin, Mohit Khatwani, Han Qi, Milad Mohammadi

Distributed Checkpointing

Owner: Jon Bolin

Jon Bolin has been working with Rodrigo on distributed checkpointing. The plan is to integrate with the upstream checkpointing API, and we already have an early POC working.

Release 2.2 - 2.3 Feature Plan

The focus will be on expanding support to provide a complete toolchain that makes the LLM UX frictionless, from training to inference.

HuggingFace/Lightning SPMD Integration

Owner: Jiewen Tan, Jon Bolin

So far, SPMD is the state of the art for achieving best-in-class distributed training performance while keeping the UX easy to use. We plan to upstream our Llama 2 integration to HF and generalize it to most dense decoder-only transformers. In parallel, we are also working with Lightning.

Deterministic data loader

Owner: Mohit Khatwani

Besides checkpointing, the ability to deterministically resume the data loader at its last position is crucial for improving the goodput of the whole training system. We are currently exploring integration with torchdata.DataLoader2 or Grain.
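The core requirement can be sketched without any particular library: if the shuffle order of an epoch depends only on a seed and the epoch number, a restarted job only needs to save (epoch, step) next to its checkpoint to re-enter the data stream exactly where it left off. The class below is a hypothetical, minimal illustration of that idea, not the torchdata or Grain API.

```python
import random
from typing import Iterator, List


class ResumableSampler:
    """Deterministic, resumable batch sampler (hypothetical sketch).

    The shuffle order of an epoch depends only on (seed, epoch), so a job that
    restarts from a checkpoint can regenerate the same order and skip directly
    to the saved step instead of replaying the data stream.
    """

    def __init__(self, dataset_size: int, batch_size: int, seed: int) -> None:
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.seed = seed

    def num_steps(self) -> int:
        return -(-self.dataset_size // self.batch_size)  # last batch may be short

    def _epoch_order(self, epoch: int) -> List[int]:
        order = list(range(self.dataset_size))
        # A private RNG seeded from (seed, epoch); no global random state is touched.
        random.Random(self.seed * 1_000_003 + epoch).shuffle(order)
        return order

    def batches(self, epoch: int, start_step: int = 0) -> Iterator[List[int]]:
        order = self._epoch_order(epoch)
        for step in range(start_step, self.num_steps()):
            yield order[step * self.batch_size:(step + 1) * self.batch_size]


sampler = ResumableSampler(dataset_size=10, batch_size=3, seed=42)
full_epoch = list(sampler.batches(epoch=1))
resumed = list(sampler.batches(epoch=1, start_step=2))  # e.g. after a restart at step 2
assert resumed == full_epoch[2:]
```

Skipping by index keeps resumption O(1) in the number of already-consumed batches; the trade-off is that any per-sample randomness (e.g. augmentation) must also be derived from (seed, epoch, index) to stay deterministic.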

TPU orchestration

Owner: Allen Wang, Jon Bolin

Distributing training jobs across thousands of chips is a very different beast from training within a single host. We aim to provide TPU orchestration managers that integrate closely with Ray and GKE, so users no longer have to write scripts that deal with bare-bones SSH. Going one step further, these helpers can be extended to deploy multi-host GPU training as well.

Profiling and debuggability

Owner: Jiewen Tan

Profiling and debugging across thousands of chips is also drastically different from doing so on a single device. Existing profiling and debugging tools are not sufficient, because the root cause is very often a network connection failure, a hardware failure on a particular chip, or a VM failure on a particular host, among others. Therefore, new fleet-level tools are needed to help users navigate these problems. At the same time, we also aim to make TensorBoard better.

Flash attention

Owner: Jiewen Tan

One trend in LLM training is the use of ever longer sequence lengths. To handle this, we plan to integrate flash attention to improve our performance with 8K or even 16K context lengths.
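The key idea behind flash attention is to visit keys and values in blocks while maintaining an online (running) softmax, so the full attention score matrix is never materialized. The sketch below shows that recurrence for a single query vector in plain Python and checks it against naive attention; it illustrates the math only and is not the kernel PyTorch/XLA plans to integrate. The function names are hypothetical.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def naive_attention(q: Vector, keys: Sequence[Vector], values: Sequence[Vector]) -> List[float]:
    """Reference: materialize every score, then softmax, then weight the values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(len(values[0]))]


def blockwise_attention(q: Vector, keys: Sequence[Vector], values: Sequence[Vector],
                        block_size: int) -> List[float]:
    """Same result, visiting keys/values one block at a time with an online softmax."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf       # max score seen so far
    running_sum = 0.0             # sum of exp(score - running_max) so far
    acc = [0.0] * len(values[0])  # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        block_scores = [scale * sum(a * b for a, b in zip(q, k))
                        for k in keys[start:start + block_size]]
        new_max = max(running_max, max(block_scores))
        # Rescale what was accumulated under the old max (exp(-inf) == 0.0 on the first block).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(block_scores, values[start:start + block_size]):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Memory for the scores drops from one entry per key to one block, which is what makes 8K to 16K contexts tractable; a real kernel also tiles the queries and keeps each block in fast on-chip memory.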


Owner: Milad Mohammadi, Han Qi

vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It uses a new attention algorithm called PagedAttention, which effectively manages attention keys and values. This allows vLLM to achieve significantly higher throughput than other LLM serving systems without requiring any model architecture changes. We plan to make it available on TPU.
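The memory-management idea behind PagedAttention resembles virtual-memory paging: each sequence's KV cache is stored in fixed-size physical blocks that need not be contiguous, and a per-sequence block table maps logical token positions to (block, offset) slots. The toy allocator below is a hypothetical sketch of that bookkeeping only; vLLM's actual implementation (block sharing, copy-on-write, the attention kernel itself) is considerably richer.

```python
from typing import Dict, List, Tuple


class PagedKVCache:
    """Toy block table in the spirit of PagedAttention (hypothetical sketch)."""

    def __init__(self, num_blocks: int, block_size: int) -> None:
        self.block_size = block_size
        self.free_blocks: List[int] = list(range(num_blocks))
        self.block_tables: Dict[int, List[int]] = {}  # seq_id -> physical block ids
        self.lengths: Dict[int, int] = {}             # seq_id -> tokens stored

    def append_token(self, seq_id: int) -> Tuple[int, int]:
        """Reserve a KV slot for one new token; return its (physical_block, offset)."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.block_size == 0:  # no block yet, or the last one is full
            if not self.free_blocks:
                raise MemoryError("KV cache is out of blocks")
            table.append(self.free_blocks.pop(0))
        self.lengths[seq_id] = length + 1
        return table[length // self.block_size], length % self.block_size

    def free(self, seq_id: int) -> None:
        """Return a finished sequence's blocks to the pool for reuse."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)
```

Because memory is reserved one block at a time rather than for the maximum sequence length up front, waste is bounded by one partially filled block per sequence, which lets more requests be batched together.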


Owner: Siyuan Liu

AutoGPTQ is an easy-to-use LLM quantization package with user-friendly APIs, based on the GPTQ algorithm. It allows users to quantize pretrained HF Transformers models to reduce their memory footprint and inference latency with minimal loss in accuracy. We plan to make it available on TPU.




Owner: Will Cromar, Xiongfei Wei

We have deprecated our old XRT runtime and made PJRT our only runtime. It has been quite stable over the past half year and consistently provides better performance and debuggability.

Release 2.2 - 2.3 Feature Plan

Xiongfei Wei is working with wbmc to bring multi-host GPU support to PyTorch/XLA's PJRT runtime. I will cover this in more detail in the GPU section.




Owner: Manfei Bai, Will Cromar

PyTorch/XLA has transitioned from depending on TensorFlow to depending on the new OpenXLA repo. This allows us to reduce our binary size and simplify our build system. Starting from 2.1, PyTorch/XLA releases its TPU wheel on PyPI.

Release 2.2 - 2.3 Feature Plan

The focus will be to maintain the stability of this dependency.




Owner: Meghan, Yeounoh Chung

The original AMP support for PyTorch/XLA was contributed by Chengji from ByteDance. It worked by intercepting the AMP call and reusing CUDA's AMP for PyTorch/XLA GPU. In 2.1, Meghan implemented AMP for TPU. Right now, AMP on PyTorch/XLA GPU still uses the same upstream CUDA AMP rules, while AMP on PyTorch/XLA TPU uses a separate set of basic AMP rules. For users, this should be transparent: the rules for the actual XLA device are picked automatically. More details can be found in this doc.




Owner: Jack Cao, Wonjoo Lee

CPU fallback

Owner: Wonjoo Lee

One major drawback of the PyTorch/XLA Dynamo integration was that PyTorch/XLA would crash if the FX graph passed down contained operations that XLA does not support. Wonjoo Lee collaborated with Sean from AWS to implement CPU fallback for these operations (shout out to Sherlock for the help).

Dynamo + SPMD

Owner: Yeounoh Chung, Jack Cao

We implemented experimental support for Dynamo + SPMD, with the limitation that the mark_sharding call must happen outside the torch.compile scope. The way it works is pretty smart (or hacky, depending on your point of view).

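```python
from collections import OrderedDict
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs  # SPMD module path in 2.1
xr.use_spmd()  # enable SPMD mode before creating XLA tensors
device, n_devices = xm.xla_device(), xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n_devices), (1, n_devices))  # 1 x N logical device mesh
linear = torch.nn.Sequential(OrderedDict(
    fc1=torch.nn.Linear(128, 128), fc2=torch.nn.Linear(128, 128))).to(device)
xla_x = torch.randn(1, 128, device=device)
# Annotate OUTSIDE torch.compile: fc2.weight dim 0 -> mesh axis 1, dim 1 -> mesh axis 0
xs.mark_sharding(linear.fc2.weight, mesh, (1, 0))
dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
```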

mark_sharding attaches the sharding annotation to the given tensor on the XLA device. During Dynamo's compilation phase, Dynamo passes the real tensor, together with its sharding annotation, to the XLA Dynamo bridge, which is enough for XLA to perform the SPMD compilation.

Inference and traceable collectives

Owner: Jiewen Tan, Liyang Lu, Milad Mohammadi, Wonjoo Lee, Yeounoh Chung, Jack Cao

So far this year, we have optimized inference for both Llama and Llama 2 with Dynamo enabled on distributed devices. Dynamo greatly reduces CPU tracing overhead and gives very competitive results on a single device of the newest Cloud TPU generation. However, it did not work well on distributed devices until we adopted traceable collectives; after that, performance excelled for all model sizes, including 70B Llama 2. We believe Dynamo is a very promising solution for inference and should become our default inference solution. Especially for LLMs, with recent technology such as vLLM, an economical, performant, and robust Python-only serving solution is entirely feasible. The main issue right now is coverage: when I tried to enable Dynamo for HF Stable Diffusion, I ran into a Dynamo bug that I needed to investigate.

Release 2.2 - 2.3 Feature Plan

Single Step Graph (for training)

Owner: Jack Cao

Having forward and backward in separate graphs remains the biggest blocker for PyTorch/XLA to adopt Dynamo as the default training graph capture mechanism. I am working with Will, Voz, and Ed to explore paths toward generating a single graph. Many thanks to Brian and Voz for onboarding me on Dynamo and aot-autograd.

Handle unexpected reshape in aot-autograd

Owner: Wonjoo Lee

During our Llama 2 experiments, we noticed that aot-autograd may apply reshapes to the input tensors of Dynamo graphs between runs. According to Brian, this is expected, since Dynamo does not consider reshape a computation. In XLA, however, reshape/view is an actual computation. Wonjoo Lee will lead the investigation of this issue and try to remove the unexpected reshapes added by aot-autograd.

Support SPMD activation sharding in dynamo

Owner: Wonjoo Lee

As discussed earlier, calling mark_sharding inside the torch.compile region is currently not supported. One potential solution we discussed with Ed is to register the pybind function that mark_sharding calls as a custom call, so that it can be represented in the FX graph. Otherwise, Dynamo crashes when it encounters a pybind call it does not understand.




Owner: Xiongfei Wei, Quansight team

Release 2.2 - 2.3 Feature Plan

Multi Host GPU

Owner: Xiongfei Wei

Xiongfei Wei is working with wbmc to bring multi-host GPU support to PyTorch/XLA's PJRT runtime. The current plan is to support torchrun for launching multi-host training, and to demonstrate scalability and performance on multi-host GPU.


We have partnered with Quansight to benchmark the models in TorchBench and understand the performance of the PyTorch/XLA GPU backend. The goal is to increase model coverage and work with the XLA:GPU team to optimize performance.




Owner: Han Qi, Siyuan Liu

We added APIs in PyTorch/XLA to lower exported PyTorch models to StableHLO. The exported StableHLO can be consumed by TensorFlow Serving for inference applications, and by any HLO- or StableHLO-compatible compiler.

For more details, please refer to

Release 2.2 - 2.3 Feature Plan

As more features are added to torch.export (e.g., scalar type capturing), we will keep torch.export support in PyTorch/XLA up to date.

Symbolic shape support: if torch.export returns a GraphModule with symbolic shapes (SymInts in the node's meta['val'] entry), we will emit StableHLO MLIR with unknown dimension sizes in the corresponding positions (e.g., tensor<?x5xf32>).
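The intended mapping from shapes to MLIR tensor types can be stated as a small rule: every concrete size is printed as a number and every symbolic size becomes `?`. The sketch below formats a ranked MLIR tensor type that way. `SymDim` is a hypothetical stand-in for a symbolic size (it plays the role a SymInt would); this is not the actual PyTorch/XLA lowering code.

```python
from typing import Sequence, Union


class SymDim:
    """Hypothetical stand-in for a symbolic dimension size (like a SymInt)."""

    def __init__(self, name: str) -> None:
        self.name = name


def mlir_tensor_type(shape: Sequence[Union[int, SymDim]], elem: str) -> str:
    """Format a ranked MLIR tensor type, emitting '?' for every symbolic dimension."""
    dims = ["?" if isinstance(d, SymDim) else str(d) for d in shape]
    return "tensor<" + "x".join(dims + [elem]) + ">"


print(mlir_tensor_type([SymDim("s0"), 5], "f32"))  # tensor<?x5xf32>
```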




Owner: Siyuan Liu, Han Qi

We benchmarked the inference performance of weight-only quantized LLaMA on TPU v4 (check out our previous blog post). The 2.1 release includes a user guide for model developers to enable weight-only quantized models on TPU.

Release 2.2 - 2.3 Feature Plan

We are working on allowing PyTorch users to quantize models with PyTorch PT2E quantization and deploy them on TPU/GPU. Besides PT2E quantization, we are also actively working on enabling users to deploy quantized models from HuggingFace on TPU. As int4 quantization has become increasingly popular, we plan to support int4 in PyTorch/XLA.
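To show what "weight-only" means in practice, here is a minimal sketch of symmetric per-output-channel quantization with a configurable bit width (8 or 4): weights are stored as small integers plus one float scale per row, while activations stay in floating point and the scale is applied after the integer dot product. This is plain round-to-nearest, not GPTQ (which additionally uses approximate second-order information to compensate rounding error) and not the PT2E flow; all function names are hypothetical.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def quantize_per_channel(weight: Matrix, bits: int = 8) -> Tuple[List[List[int]], List[float]]:
    """Symmetric round-to-nearest quantization with one scale per output row."""
    qmax = 2 ** (bits - 1) - 1  # 127 for int8, 7 for int4
    q_rows, scales = [], []
    for row in weight:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / qmax if max_abs > 0 else 1.0  # avoid dividing by zero
        q_rows.append([max(-qmax, min(qmax, round(w / scale))) for w in row])
        scales.append(scale)
    return q_rows, scales


def dequantize(q_rows: List[List[int]], scales: List[float]) -> Matrix:
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


def matvec_weight_only(q_rows: List[List[int]], scales: List[float], x: List[float]) -> List[float]:
    """y_i = scale_i * sum_j q_ij * x_j: integer weights, float activations."""
    return [s * sum(q * xj for q, xj in zip(row, x)) for row, s in zip(q_rows, scales)]
```

Per-row scales keep one outlier row from degrading the precision of the others; int4 halves storage again relative to int8 at the cost of a 16-level grid, which is why more careful methods such as GPTQ matter more at 4 bits.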




Owner: Will Cromar, Jack Cao

Release 2.2 - 2.3 Feature Plan

Minimize code change required from PyTorch native to PyTorch/XLA

Owner: Will Cromar

Today, PyTorch/XLA maintains a set of private APIs that users have to apply to their model code, which has become a blocker for users migrating to PyTorch/XLA seamlessly. The goal of this project is to remove or upstream as many of these private APIs as possible.

Enhance debugging tool

Owner: Jack Cao

Most users are confused when recompilation happens in the XLA layer and performance becomes terrible. PyTorch/XLA should do a better job of pointing out where the dynamism in the graph comes from (dynamic shapes? control flow?) and suggesting how to fix the code.
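One simple form such a hint could take: remember the input shapes each graph was compiled for and, when a new compilation is triggered, report which input's shape changed. The tracker below is a hypothetical sketch of that bookkeeping, not an existing PyTorch/XLA API; it only covers shape-driven recompilation, not recompilation caused by control flow.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


class RecompileTracker:
    """Hypothetical helper: explain shape-driven recompilations of a traced graph."""

    def __init__(self) -> None:
        self.seen: Dict[str, List[Tuple[Shape, ...]]] = defaultdict(list)

    def record(self, graph_name: str, input_shapes: Tuple[Shape, ...]) -> List[str]:
        """Return human-readable hints; an empty list means a cache hit or first compile."""
        history = self.seen[graph_name]
        if input_shapes in history:
            return []  # already compiled for these shapes: no recompilation
        hints = []
        if history:
            previous = history[-1]
            if len(previous) != len(input_shapes):
                hints.append(f"{graph_name}: number of inputs changed "
                             f"({len(previous)} -> {len(input_shapes)})")
            for i, (old, new) in enumerate(zip(previous, input_shapes)):
                if old != new:
                    hints.append(f"{graph_name}: input {i} changed shape {old} -> {new}; "
                                 "consider padding or bucketing this dimension")
        history.append(input_shapes)
        return hints


tracker = RecompileTracker()
tracker.record("train_step", ((8, 128), (128, 10)))
for hint in tracker.record("train_step", ((7, 128), (128, 10))):  # e.g. a short last batch
    print(hint)
```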

Better DDP support

Owner: TBD

PyTorch/XLA has been using its own collective ops (xm.all_reduce, xm.reduce_scatter) in most of our tutorials. We added DDP support last year, but there are a few open issues. The goal of this project is to improve DDP support and replace PyTorch/XLA's own collective ops for data parallelism.




Owner: Jiewen Tan, Wonjoo Lee

We adopted upstream functionalization (check out Brian's post) in place of PyTorch/XLA's own view implementation (for the history of this implementation, check here). This reduces the complexity of our code and fixes a longstanding issue where view relationships became invalid after a mark_step. For the 2.1 release, we provided the XLA_DISABLE_FUNCTIONALIZATION flag to make this configurable, but the long-term plan is to remove the flag.

Release 2.2 - 2.3 Feature Plan

Optimizer Tracing Time Regression

Owner: Wonjoo Lee

We noticed a tracing-time regression in the Adam optimizer. For some models this is especially problematic, because the execution time is not long enough to hide the tracing time. Wonjoo Lee will investigate this regression and validate the result on a set of TorchBench models to make sure we do not regress performance.



Bounded Dynamic Shape

Owner: Xiongfei Wei

The scope of this project was to support bounded dynamic shape execution on TPU and GPU through XLA. This allows ops like nonzero to produce an output whose dynamic dimension has an upper bound. Xiongfei Wei made significant progress in reducing the recompilation count for a simple linear model with nonzero in the middle. However, due to the complexity of the project, team priorities, and resourcing issues, we have paused development on bounded dynamic shapes.
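The reason an upper bound helps can be shown with a toy version of nonzero: instead of returning a data-dependent number of indices (which would force a new compilation for every distinct count), it always returns a buffer sized to the static bound, here the input length, plus the true count. The padding value -1 and the function name are assumptions for illustration; XLA tracks the runtime size of a bounded dimension as metadata rather than with a sentinel.

```python
from typing import List, Tuple


def bounded_nonzero(x: List[float]) -> Tuple[List[int], int]:
    """nonzero with a static upper bound: the output always has len(x) slots,
    padded with -1, plus the number of valid entries. The output shape depends
    only on the input shape, so a compiler can specialize on the bound."""
    bound = len(x)
    idx = [i for i, v in enumerate(x) if v != 0]
    return idx + [-1] * (bound - len(idx)), len(idx)


print(bounded_nonzero([0.0, 3.0, 0.0, 5.0]))  # ([1, 3, -1, -1], 2)
```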

Release 2.2 - 2.3 Feature Plan

We will maintain all the existing infrastructure and tests we developed for bounded dynamic shapes, and we hope to resume the work in the future.



TPU Specific

Cloud TPU v5e support

Owner: Manfei Bai, Will Cromar, Jon Bolin, Jiewen Tan

A tremendous amount of effort has gone into making Cloud TPU v5e deliver best-in-class Perf/TCO for inference, and we are working on bringing that to training as well.

Cloud TPU MultiSlice support

Owner: Jon Bolin, Mohit Khatwani, Jiewen Tan

Traditionally, TPUs are connected through ICI within a cluster called a pod. Recent LLMs demand so much compute that not even a full TPU pod suffices. Hence, we recently developed technology that allows TPU slices to be connected through DCN as well. This not only allows us to harness the compute power of multiple TPU pods, but also lets us interconnect TPU slices within a pod more flexibly.
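A common way to think about a multislice job is as a two-level device mesh: one axis crosses slice boundaries over DCN, and the other stays inside a slice on ICI, so bandwidth-hungry communication can be kept on ICI while lighter traffic (often data-parallel gradient reduction) goes over DCN. The sketch below builds such a mesh and lists the device groups a collective along each axis would involve. It assumes device ids are numbered contiguously within each slice; the helpers are hypothetical, not a PyTorch/XLA API.

```python
from typing import List


def hybrid_mesh(num_slices: int, chips_per_slice: int) -> List[List[int]]:
    """Arrange global device ids as a [num_slices, chips_per_slice] grid.
    Axis 0 crosses slice boundaries (DCN); axis 1 stays inside a slice (ICI)."""
    return [[s * chips_per_slice + c for c in range(chips_per_slice)]
            for s in range(num_slices)]


def groups_along(mesh: List[List[int]], axis: int) -> List[List[int]]:
    """Device groups that would communicate in a collective over one mesh axis."""
    if axis == 1:
        return [list(row) for row in mesh]    # ICI groups: one per slice
    return [list(col) for col in zip(*mesh)]  # DCN groups: same position in every slice


print(groups_along(hybrid_mesh(2, 4), axis=0))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```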




Owner: TBD

Release 2.2 - 2.3 Feature Plan

PyTorch/XLA aims to provide an API to execute Triton GPU kernels on PyTorch/XLA:GPU through an XLA custom call. This will allow users to seamlessly combine their Triton kernels with compiled XLA programs.


2.1 is a huge release for PyTorch/XLA, and we are looking forward to bringing all of these roadmap items to reality.