It has been a while since I wrote the last dev update for PyTorch/XLA. The original plan was to post this kind of update quarterly, but my long vacation earlier this year disrupted the schedule a bit. A couple of folks have asked me about the PyTorch/XLA roadmap, so going forward I will try to write a dev update at least once per release.
The content of this document overlaps with our release notes. The goal here is to explain why we made some of our technical decisions and to share our plan for the next few months.
GSPMD
Owner: Yeounoh Chung, Jon Bolin, Jiewen Tan, Jack Cao
API
Owner: Yeounoh Chung
We introduced the mark_sharding API, which assigns a sharding annotation to a tensor. The idea is that users only need to call this API on a few key tensors and let the compiler handle the actual communication across accelerators. More technical details can be found in this blog post.
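To make the partition spec concrete, here is a small pure-Python sketch of the shard shape each device ends up holding. It assumes the semantics as I understand them for mark_sharding: entry i of the partition spec names the mesh axis that tensor dimension i is split along (or None if that dimension is replicated), and uneven splits are padded up. shard_shape is a hypothetical helper for illustration, not part of the PyTorch/XLA API.

```python
import math
from typing import Optional, Sequence, Tuple


def shard_shape(tensor_shape: Sequence[int],
                mesh_shape: Sequence[int],
                partition_spec: Sequence[Optional[int]]) -> Tuple[int, ...]:
    """Per-device shard shape for a tensor annotated with a partition spec.

    partition_spec[i] is the mesh axis that tensor dim i is split along,
    or None if dim i is replicated. Uneven splits are rounded up (padded).
    """
    if len(partition_spec) != len(tensor_shape):
        raise ValueError("partition_spec needs one entry per tensor dimension")
    shape = []
    for dim_size, mesh_axis in zip(tensor_shape, partition_spec):
        if mesh_axis is None:
            shape.append(dim_size)  # replicated: each device keeps the full dim
        else:
            shape.append(math.ceil(dim_size / mesh_shape[mesh_axis]))
    return tuple(shape)


# A 128x128 weight on a (1, 8) mesh with partition spec (1, 0):
# tensor dim 0 is split 8 ways along mesh axis 1, dim 1 along mesh axis 0 (size 1).
print(shard_shape((128, 128), (1, 8), (1, 0)))
```

With that spec, each of the 8 devices holds a 16x128 slice of the weight.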
Benchmark and Performance Optimization
Owner: Jiewen Tan, Jon Bolin, Yeounoh Chung
We used Llama2 and GPT2 to validate GSPMD performance on TPU. For Llama2, we used this GitHub fork and achieved 53%+ model FLOPs utilization (MFU) on the 7B, 13B, and 70B Llama2 models. There will be a blog post sharing more details soon.
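For readers unfamiliar with the metric, model FLOPs utilization compares the FLOPs a model needs per second at its observed training throughput with the hardware's peak FLOPs. A common approximation for dense decoder-only transformer training counts about 6 FLOPs per parameter per token; the sketch below uses that approximation and ignores attention FLOPs, so it is an estimate under that assumption and not necessarily the exact accounting behind the numbers above. The inputs in the usage line are made-up values, not our measurements.

```python
def model_flops_utilization(num_params: float,
                            tokens_per_second: float,
                            peak_flops_per_second: float) -> float:
    """Approximate MFU for dense decoder-only transformer training.

    Assumes ~6 FLOPs per parameter per token (about 2 for the forward pass
    and 4 for the backward pass) and ignores attention FLOPs.
    peak_flops_per_second is the aggregate peak across all chips used.
    """
    if peak_flops_per_second <= 0:
        raise ValueError("peak_flops_per_second must be positive")
    achieved_flops_per_second = 6.0 * num_params * tokens_per_second
    return achieved_flops_per_second / peak_flops_per_second


# Hypothetical numbers: a 7B model at 1,000 tokens/s on hardware with 1e14 peak FLOP/s.
print(f"MFU: {model_flops_utilization(7e9, 1_000, 1e14):.0%}")
```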
Release 2.2 - 2.3 Feature Plan
GPU SPMD
Owner: Yeounoh Chung, Xiongfei Wei
The early experiments were mostly done on TPUs. Our friends at Alibaba have been testing GSPMD on GPUs and have achieved promising results. We also plan to increase test coverage and performance validation on GPUs.
Auto Sharding
Owner: Yeounoh Chung
OpenXLA has existing tooling that automatically adds sharding annotations to an HLO graph. We plan to test this feature and bring it to PyTorch/XLA.
Distributed Tensor Integration
Owner: Yeounoh Chung, Jiewen Tan
Our team has been working closely with Wanchao and Junjie on the Distributed Tensor integration, since both projects started at roughly the same time. Right now Yeounoh Chung and Jiewen Tan are working on a POC that integrates with the Distributed Tensor API, so users can use a unified API for tensor parallelism. We believe this is a good integration point because the Distributed Tensor API is high-level enough that each backend can implement the actual sharding differently (I believe that for PyTorch native, the decomposition into collectives happens at the framework level, while for XLA the decomposition into collective ops happens at the compiler level).
There is one major difference between XLA GSPMD and native PyTorch DTensor: the number of training processes. PyTorch/XLA GSPMD spawns only one process per host, which executes on all accessible devices, whereas native DTensor spawns one process per device. Since XLA compilation can be expensive, using only one process per host effectively reduces the number of compilations we need to perform per host. We are aware that this divergence might cause some confusion, and we will discuss this topic further with Wanchao.
SPMD Debugging Tools
Owner: Manfei Bai, Yeounoh Chung
SPMD computations can be difficult to reason about, especially the sharding annotations and where each shard actually lives. We plan to add a debugging package to help debug SPMD workloads.
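As a flavor of what such a tool could report, the sketch below maps each device in a mesh to the slice of the tensor it holds. It assumes device ids are laid out row-major over the mesh (as with numpy.reshape), that each partition-spec entry is a mesh axis index or None, and that uneven splits are padded at the end. shard_placement is a hypothetical helper, not the planned debugging package.

```python
import itertools
import math
from typing import Dict, Optional, Sequence, Tuple

Slice = Tuple[int, int]  # half-open [start, stop) range along one tensor dimension


def shard_placement(tensor_shape: Sequence[int],
                    mesh_shape: Sequence[int],
                    partition_spec: Sequence[Optional[int]]) -> Dict[int, Tuple[Slice, ...]]:
    """Return, for every device id, the slice of the tensor that device holds."""
    placement: Dict[int, Tuple[Slice, ...]] = {}
    # itertools.product walks mesh coordinates in row-major order, so the
    # enumeration index doubles as the device id under our layout assumption.
    mesh_coords = itertools.product(*(range(n) for n in mesh_shape))
    for device_id, coords in enumerate(mesh_coords):
        slices = []
        for dim_size, mesh_axis in zip(tensor_shape, partition_spec):
            if mesh_axis is None:
                slices.append((0, dim_size))  # replicated: the device holds the whole dim
                continue
            chunk = math.ceil(dim_size / mesh_shape[mesh_axis])
            start = min(coords[mesh_axis] * chunk, dim_size)
            slices.append((start, min(start + chunk, dim_size)))
        placement[device_id] = tuple(slices)
    return placement


# Dim 0 split across mesh axis 0, dim 1 replicated: devices 0 and 1 hold the same rows.
for device, slices in shard_placement((4, 4), (2, 2), (0, None)).items():
    print(f"device {device}: {slices}")
```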
LLM Toolchain
Owner: Jiewen Tan, Jon Bolin, Mohit Khatwani, Han Qi, Milad Mohammadi
Distributed Checkpointing
Owner: Jon Bolin
Jon Bolin has been working with Rodrigo on distributed checkpointing. The plan is to integrate with the upstream checkpoint API, and we already have an early POC working.
Release 2.2 - 2.3 Feature Plan
The focus will be on expanding support to provide a complete toolchain that makes the LLM user experience frictionless, from training to inference.
HuggingFace/Lightning SPMD Integration
Owner: Jiewen Tan, Jon Bolin
So far, SPMD is the state-of-the-art approach for achieving best-in-class distributed training performance while maintaining an easy-to-use UX. We are planning to upstream our Llama 2 integration to HF and generalize it to most dense decoder-only transformers. In parallel, we are also working with Lightning.
Deterministic data loader
Owner: Mohit Khatwani
Besides checkpointing, the ability to deterministically resume the data loader from where it left off is also crucial for improving the goodput of the whole training system. Right now, we are exploring integration with torchdata.DataLoader2 or Grain.
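The core idea behind deterministic resumption can be sketched without any particular library: derive each epoch's shuffle order from a seed, track how many samples have been consumed, and save that position alongside the model checkpoint. The class below is a hypothetical illustration, not the torchdata or Grain API, and it only covers index sampling (no worker processes or prefetch buffers).

```python
import random
from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class SamplerState:
    epoch: int = 0
    position: int = 0  # number of samples already handed out in this epoch


class ResumableShuffleSampler:
    """Yields dataset indices in a seeded per-epoch order and can resume mid-epoch."""

    def __init__(self, dataset_size: int, seed: int) -> None:
        self.dataset_size = dataset_size
        self.seed = seed
        self.state = SamplerState()

    def _epoch_order(self, epoch: int) -> List[int]:
        order = list(range(self.dataset_size))
        # An integer seed derived from (seed, epoch) makes every epoch reproducible.
        random.Random(self.seed * 1_000_003 + epoch).shuffle(order)
        return order

    def __iter__(self) -> Iterator[int]:
        order = self._epoch_order(self.state.epoch)
        while self.state.position < len(order):
            index = order[self.state.position]
            self.state.position += 1
            yield index
        # Epoch finished: the next iteration starts the following epoch.
        self.state = SamplerState(epoch=self.state.epoch + 1, position=0)

    def state_dict(self) -> dict:
        return {"epoch": self.state.epoch, "position": self.state.position}

    def load_state_dict(self, state: dict) -> None:
        self.state = SamplerState(**state)


sampler = ResumableShuffleSampler(dataset_size=10, seed=42)
it = iter(sampler)
consumed = [next(it) for _ in range(4)]
checkpoint = sampler.state_dict()  # would be saved next to the model checkpoint

restored = ResumableShuffleSampler(dataset_size=10, seed=42)
restored.load_state_dict(checkpoint)
print(consumed, list(restored))
```

Restoring from the saved state replays exactly the remaining samples of the interrupted epoch, in the same order.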
TPU orchestration
Owner: Allen Wang, Jon Bolin
Distributing training jobs across thousands of chips is a very different beast from running within a single host. We aim to provide TPU orchestration managers that integrate tightly with Ray and GKE, so users do not have to write scripts that deal with bare-bones SSH. Going one step further, these helpers can also be extended to deploy multi-host GPU training.
Profiling and debuggability
Owner: Jiewen Tan
Profiling and debugging across thousands of chips is also drastically different from doing so on a single device. Existing profiling and debugging tools are not sufficient because the root cause is very often a network connection failure, a hardware failure on a particular chip, or a VM failure on a particular host. Therefore, new fleet-level tools are needed to help users navigate these issues. At the same time, we also aim to make TensorBoard better.
Flash attention
Owner: Jiewen Tan
One trend in LLM training is the use of ever-longer sequence lengths. To address this, we plan to integrate flash attention to improve our performance with 8K or even 16K context lengths.
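The trick that lets flash attention process long sequences block by block is the online softmax: keep a running maximum and a running normalizer, and rescale the partial output whenever a new block raises the maximum. The pure-Python sketch below shows only that recurrence for a single query vector and checks it against naive attention. The real kernel is fused, tiles over queries as well as keys, and keeps tiles in on-chip memory; none of this reflects how PyTorch/XLA will integrate it.

```python
import math
import random
from typing import List, Sequence

Vector = List[float]
Matrix = Sequence[Sequence[float]]


def _scores(q: Sequence[float], keys: Matrix) -> Vector:
    """Scaled dot-product scores q.k / sqrt(d) for each key."""
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]


def naive_attention(q: Sequence[float], keys: Matrix, values: Matrix) -> Vector:
    """Reference: full softmax over all scores, then a weighted sum of values."""
    scores = _scores(q, keys)
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def tiled_attention(q: Sequence[float], keys: Matrix, values: Matrix,
                    block_size: int) -> Vector:
    """Same result, computed one key/value block at a time with an online softmax."""
    running_max = -math.inf        # largest score seen so far
    running_sum = 0.0              # sum of exp(score - running_max) so far
    acc = [0.0] * len(values[0])   # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        block_scores = _scores(q, keys[start:start + block_size])
        new_max = max(running_max, max(block_scores))
        # Rescale earlier partial results so they are relative to the new maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(block_scores, values[start:start + block_size]):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


rng = random.Random(0)
dim, seq_len = 4, 10
q = [rng.uniform(-1, 1) for _ in range(dim)]
keys = [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(seq_len)]
values = [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(seq_len)]
diff = max(abs(a - b) for a, b in zip(naive_attention(q, keys, values),
                                      tiled_attention(q, keys, values, block_size=3)))
print(f"max abs difference: {diff:.2e}")
```

Only one block of scores is alive at a time, which is what keeps memory flat as the sequence grows.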
vLLM
Owner: Milad Mohammadi, Han Qi
vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It uses a new attention algorithm called PagedAttention, which efficiently manages attention keys and values. This allows vLLM to achieve significantly higher throughput than other LLM serving systems without requiring any model architecture changes. We are planning to make it available on TPU.
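To illustrate the idea behind PagedAttention, the toy class below keeps a per-sequence block table that maps logical token positions to fixed-size physical KV-cache blocks. It allocates a new block only when the previous one fills up and returns blocks to a free list when a sequence finishes. This is a simplified sketch, not vLLM's API: it omits block sharing and copy-on-write, and it stores no actual keys or values.

```python
from typing import Dict, List, Tuple


class PagedKVCache:
    """Toy block table: logical token positions -> (physical block, offset)."""

    def __init__(self, num_blocks: int, block_size: int) -> None:
        self.block_size = block_size
        self.free_blocks: List[int] = list(range(num_blocks))
        self.block_tables: Dict[int, List[int]] = {}
        self.lengths: Dict[int, int] = {}

    def append_token(self, seq_id: int) -> Tuple[int, int]:
        """Reserve a slot for the next token of seq_id; returns (block, offset)."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.block_size == 0:
            # The last block is full (or there is none yet): allocate on demand.
            if not self.free_blocks:
                raise MemoryError("no free KV-cache blocks")
            table.append(self.free_blocks.pop(0))
        self.lengths[seq_id] = length + 1
        return table[length // self.block_size], length % self.block_size

    def free(self, seq_id: int) -> None:
        """Return all blocks of a finished sequence to the free list."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


cache = PagedKVCache(num_blocks=4, block_size=2)
print([cache.append_token(0) for _ in range(3)])  # 3 tokens use 2 blocks
print(cache.append_token(1))                      # another sequence gets its own block
```

Because memory is handed out in small blocks rather than reserved for the maximum sequence length up front, little KV-cache space is wasted per sequence.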
AutoGPTQ
Owner: Siyuan Liu
AutoGPTQ is an easy-to-use large language model (LLM) quantization package with user-friendly APIs, based on the GPTQ algorithm. It allows users to quantize pretrained HF Transformers models to reduce their memory footprint and inference latency with minimal loss in accuracy. We are planning to make it available on TPU.
PJRT
Owner: Will Cromar, Xiongfei Wei
We have deprecated our old XRT runtime and made PJRT our only runtime. It has been quite stable over the past half year and consistently provides better performance and debuggability.
Release 2.2 - 2.3 Feature Plan
Xiongfei Wei is working with wbmc to bring multi-host GPU support to PyTorch/XLA's PJRT runtime. I will say more about this in the GPU section.
OpenXLA
Owner: Manfei Bai, Will Cromar
PyTorch/XLA has transitioned from depending on TensorFlow to depending on the new OpenXLA repo. This allows us to reduce our binary size and simplify our build system. Starting from 2.1, PyTorch/XLA releases its TPU wheel on PyPI.
Release 2.2 - 2.3 Feature Plan
The focus will be to maintain the stability of this dependency.
AMP for TPU
Owner: Meghan, Yeounoh Chung
The original AMP support for PyTorch/XLA was contributed by Chengji from ByteDance. The implementation intercepted the AMP call and reused CUDA's AMP for PyTorch/XLA GPU. In 2.1, Meghan implemented AMP for TPU. Right now, AMP on PyTorch/XLA GPU still uses the same upstream CUDA AMP rules, while AMP on PyTorch/XLA TPU uses a separate set of basic AMP rules. For users, picking the rule set for the actual underlying XLA device should be transparent. More details can be found in this doc.
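The sketch below illustrates why the switch can be transparent: user code requests autocast the same way, and the rule set is chosen by the hardware behind the XLA device. The op names and table contents are illustrative placeholders, not the actual autocast op lists in PyTorch or PyTorch/XLA. The only detail carried over from practice is that CUDA autocast defaults to float16 while TPU autocast uses bfloat16.

```python
from typing import Dict, Optional

# Illustrative rule tables only; op names and membership are placeholders.
CUDA_STYLE_RULES: Dict[str, str] = {"matmul": "float16", "conv2d": "float16"}
TPU_STYLE_RULES: Dict[str, str] = {"matmul": "bfloat16", "conv2d": "bfloat16"}
RULES_BY_HARDWARE: Dict[str, Dict[str, str]] = {
    "GPU": CUDA_STYLE_RULES,
    "TPU": TPU_STYLE_RULES,
}


def autocast_dtype(op_name: str, xla_hardware: str) -> Optional[str]:
    """Pick the autocast dtype for an op based on the hardware behind the XLA device.

    Returns None when the op has no rule, meaning it keeps its input dtype.
    """
    try:
        rules = RULES_BY_HARDWARE[xla_hardware]
    except KeyError:
        raise ValueError(f"no autocast rules for hardware {xla_hardware!r}") from None
    return rules.get(op_name)


# The same user-level request resolves differently per backend.
for hardware in ("GPU", "TPU"):
    print(hardware, autocast_dtype("matmul", hardware))
```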
Dynamo
Owner: Jack Cao, Wonjoo Lee
CPU fallback
Owner: Wonjoo Lee
One major drawback of the PyTorch/XLA Dynamo integration is that PyTorch/XLA crashes if the FX graph passed down contains operations that XLA does not support. Wonjoo Lee collaborated with Sean from AWS to implement a CPU fallback for these operations (shout out to Sherlock for the help).
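Conceptually, a CPU fallback splits the captured graph into runs of ops that XLA can compile and runs that must execute eagerly on CPU, with tensors moved between devices at the boundaries. The sketch below shows that grouping on a simplified, linear op sequence. The op names are hypothetical, and this is not how the actual implementation partitions the FX graph (which is a DAG, not a list).

```python
from typing import List, Sequence, Set, Tuple

Segment = Tuple[str, List[str]]  # (target device, ops in execution order)


def partition_for_fallback(ops: Sequence[str], xla_supported: Set[str]) -> List[Segment]:
    """Group consecutive ops into XLA-compiled segments and CPU-fallback segments."""
    segments: List[Segment] = []
    for op in ops:
        target = "xla" if op in xla_supported else "cpu"
        if segments and segments[-1][0] == target:
            segments[-1][1].append(op)  # same device: extend the current segment
        else:
            segments.append((target, [op]))  # device changes: start a new segment
    return segments


# Hypothetical op names; "my_custom_op" stands in for something XLA cannot lower.
plan = partition_for_fallback(["linear", "relu", "my_custom_op", "linear"],
                              {"linear", "relu"})
for target, segment_ops in plan:
    print(target, segment_ops)
```

Each "xla" segment would be compiled as its own graph, so frequent fallbacks fragment the program and cost performance, which is why fallback is a safety net rather than a goal.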
Dynamo + SPMD
Owner: Yeounoh Chung, Jack Cao
We implemented experimental support for Dynamo + SPMD, with the limitation that the mark_sharding call must happen outside the torch.compile scope. The way it works is pretty smart/hacky:
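```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# Assumes SPMD mode is enabled (xr.use_spmd()), `linear` is a module on the XLA
# device with an `fc2` layer, and `mesh` is an xs.Mesh of shape (1, num_devices).
xla_x = torch.randn(1, 128, device=xm.xla_device())
xs.mark_sharding(linear.fc2.weight, mesh, (1, 0))  # must happen outside torch.compile
dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
```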
mark_sharding attaches the sharding annotation to the given torch tensor on the XLA device. During Dynamo's compilation phase, Dynamo passes the real tensor, annotation included, to the XLA Dynamo bridge, which gives XLA enough information to perform the SPMD compilation.
Inference and traceable collectives
Owner: Jiewen Tan, Liyang Lu, Milad Mohammadi, Wonjoo Lee, Yeounoh Chung, Jack Cao
So far this year, we have optimized inference for both Llama and Llama 2 with Dynamo enabled on distributed devices. Native Dynamo greatly reduces CPU tracing overhead and gives very competitive results on a single device of the newest Cloud TPU generation. However, it did not work well on distributed devices until we adopted traceable collectives; after that, performance excelled across all model sizes, including the 70B Llama 2. We believe Dynamo is a very promising solution for inference and should become our default inference solution. Especially for LLMs, with the latest technology such as vLLM, an economical, performant, and robust Python-only serving solution is totally feasible. Currently, the main issue is coverage: when I tried to enable Dynamo for HF Stable Diffusion, I ran into a Dynamo bug that I needed to investigate.
Release 2.2 - 2.3 Feature Plan
Single Step Graph (for training)
Owner: Jack Cao
Having forward and backward in separate graphs remains the biggest blocker for PyTorch/XLA to adopt Dynamo as the default graph capture mechanism for training. I am working with Will, Voz, and Ed to explore a path to generating a single graph. Many thanks to Brian and Voz for onboarding me onto Dynamo and aot-autograd.
Handle unexpected reshape in aot-autograd
Owner: Wonjoo Lee
During our Llama2 experiments, we noticed that aot-autograd may apply reshapes to the input tensors of Dynamo graphs between runs. I talked to Brian, and this is expected, since reshape is not considered a computation in Dynamo. However, reshape/view is an actual computation in XLA. Wonjoo Lee will take the lead on investigating this issue and try to remove the unexpected reshapes added by aot-autograd.
Support SPMD activation sharding in dynamo
Owner: Wonjoo Lee
As discussed earlier, calling mark_sharding inside a torch.compile region is currently not supported. One potential solution we discussed with Ed is to register the pybind that mark_sharding calls as a custom call, so that it can be represented in the FX graph. Otherwise, Dynamo crashes when it encounters a pybind it does not understand.
GPU
Owner: Xiongfei Wei, Quansight team
Release 2.2 - 2.3 Feature Plan
Multi Host GPU
Owner: Xiongfei Wei
Xiongfei Wei is working with wbmc to bring multi-host GPU support to PyTorch/XLA's PJRT runtime. The current plan is to support torchrun for bringing up multi-host training and to demonstrate scalability and performance on multi-host GPU.
Benchmark
We partnered with Quansight to start an effort to benchmark the models in torchbench and understand the performance of the PyTorch/XLA GPU backend. The goal is to increase model coverage and work with the XLA:GPU team to optimize performance.
Export
Owner: Han Qi, Siyuan Liu
We added APIs in PyTorch/XLA to lower exported PyTorch models to StableHLO. The exported StableHLO can be consumed by TensorFlow Serving for inference applications, or by any HLO- or StableHLO-compatible compiler.
For more details, please refer to https://github.com/pytorch/xla/blob/master/docs/stablehlo.md
Release 2.2 - 2.3 Feature Plan
As more features are added to torch.export (e.g. scalar type capturing), we will keep torch.export support in PyTorch/XLA up to date.
Symbolic shape support: if torch.export returns a GraphModule with symbolic shapes (i.e. SymInts in the meta['val'] dict), we will emit StableHLO MLIR with unknown dimension sizes in the corresponding positions (e.g. tensor<?x5xf32>).
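The mapping from shapes to StableHLO tensor types can be sketched as follows, with None standing in for a symbolic (SymInt) dimension. stablehlo_tensor_type is a hypothetical helper for illustration; the real lowering works on the exported graph rather than on Python tuples.

```python
from typing import Optional, Sequence


def stablehlo_tensor_type(shape: Sequence[Optional[int]], element_type: str = "f32") -> str:
    """Format an MLIR ranked tensor type, using '?' for symbolic (unknown) dimensions."""
    dims = ["?" if d is None else str(d) for d in shape]
    # MLIR joins every dimension and the element type with 'x', e.g. tensor<2x3xf32>.
    return "tensor<" + "x".join(dims + [element_type]) + ">"


print(stablehlo_tensor_type((None, 5)))  # a symbolic batch dimension
```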
Quantization
Owner: Siyuan Liu, Han Qi
We benchmarked the inference performance of weight-only quantized LLaMA on TPU v4 (check out our previous blog post). A user guide is provided in the 2.1 release for model developers to enable weight-only quantized models on TPU.
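For intuition on what weight-only quantization does, the sketch below applies symmetric per-output-channel int8 quantization with round-to-nearest and then dequantizes; activations would stay in floating point. It is a generic illustration, not necessarily the exact scheme in our user guide, and it is simpler than GPTQ, which additionally compensates for rounding error using approximate second-order information.

```python
from typing import List, Sequence, Tuple

Matrix = Sequence[Sequence[float]]


def quantize_per_channel_int8(weight: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Symmetric per-row (output channel) int8 quantization with round-to-nearest."""
    q_rows: List[List[int]] = []
    scales: List[float] = []
    for row in weight:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / 127.0 if max_abs > 0 else 1.0  # all-zero rows get scale 1
        q_rows.append([max(-127, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return q_rows, scales


def dequantize(q_rows: Sequence[Sequence[int]], scales: Sequence[float]) -> List[List[float]]:
    """Recover approximate float weights: w ~= q * scale for each row."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


weight = [[0.5, -1.0, 0.25], [2.0, 0.0, -0.5]]
q, scales = quantize_per_channel_int8(weight)
print(q)
print(dequantize(q, scales))
```

Storing int8 values plus one float scale per row cuts weight memory roughly 4x versus float32, at the cost of a rounding error of at most half a scale step per weight.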
Release 2.2 - 2.3 Feature Plan
We are working on allowing PyTorch users to quantize models with PyTorch PT2E quantization and deploy them on TPU/GPU. Besides PT2E quantization, we are also actively working on enabling users to deploy quantized models from HuggingFace on TPU. As int4 quantization has become increasingly popular, we plan to support int4 in PyTorch/XLA.
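Because int4 values are narrower than any native storage type, they are usually packed two per byte. The sketch below shows one such packing (low nibble first, two's complement within each nibble); the nibble order is an assumption for illustration, not the layout PyTorch/XLA will use.

```python
from typing import List, Sequence


def pack_int4(values: Sequence[int]) -> bytes:
    """Pack signed int4 values (-8..7) two per byte, low nibble first."""
    for v in values:
        if not -8 <= v <= 7:
            raise ValueError(f"{v} does not fit in signed int4")
    padded = list(values) + [0] * (len(values) % 2)  # pad odd-length input
    out = bytearray()
    for low, high in zip(padded[0::2], padded[1::2]):
        # '& 0xF' keeps the low 4 bits, i.e. the two's complement nibble.
        out.append((low & 0xF) | ((high & 0xF) << 4))
    return bytes(out)


def unpack_int4(packed: bytes, count: int) -> List[int]:
    """Inverse of pack_int4; count drops the padding nibble, if any."""
    values: List[int] = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            values.append(nibble - 16 if nibble >= 8 else nibble)  # sign-extend
    return values[:count]


data = [-8, -1, 0, 7, 3]
packed = pack_int4(data)
print(len(packed), unpack_int4(packed, len(data)))
```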
Usability
Owner: Will Cromar, Jack Cao
Release 2.2 - 2.3 Feature Plan
Minimize code change required from PyTorch native to PyTorch/XLA
Owner: Will Cromar
Today, PyTorch/XLA maintains a set of private APIs that users have to apply to their model code, and this has become a blocker for users migrating to PyTorch/XLA seamlessly. The goal of this project is to remove or upstream as many of these private APIs as possible.
Enhance debugging tool
Owner: Jack Cao
When recompilation happens in the XLA layer and performance becomes terrible, most users find it confusing. PyTorch/XLA should do a better job of pointing out where the dynamism in the graph comes from (dynamic shapes? control flow?) and suggest how to fix the code.
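As a rough illustration of the kind of diagnostic we have in mind, the toy tracker below caches compiled graphs by input shapes and, on a cache miss, reports which input's shape changed. CompileCacheTracker is hypothetical. The real PyTorch/XLA cache is keyed on the traced graph itself, so control-flow changes show up as different graphs rather than as shape changes, and this sketch covers only the dynamic-shape case.

```python
from typing import Dict, Optional, Sequence, Set, Tuple

Shape = Tuple[int, ...]


class CompileCacheTracker:
    """Toy compilation cache keyed by (graph name, input shapes) that explains misses."""

    def __init__(self) -> None:
        self._compiled: Set[Tuple[str, Tuple[Shape, ...]]] = set()
        self._last_seen: Dict[str, Tuple[Shape, ...]] = {}
        self.compile_count = 0

    def run(self, graph_name: str, input_shapes: Sequence[Sequence[int]]) -> Optional[str]:
        """Record one execution; return a message if it triggered a compilation."""
        shapes = tuple(tuple(s) for s in input_shapes)
        previous = self._last_seen.get(graph_name)
        self._last_seen[graph_name] = shapes
        if (graph_name, shapes) in self._compiled:
            return None  # cache hit: reuse the compiled program
        self._compiled.add((graph_name, shapes))
        self.compile_count += 1
        if previous is None:
            return f"{graph_name}: first compilation"
        if len(previous) != len(shapes):
            return (f"{graph_name}: recompiled, number of inputs changed "
                    f"({len(previous)} -> {len(shapes)})")
        changes = [f"input {i}: {old} -> {new}"
                   for i, (old, new) in enumerate(zip(previous, shapes)) if old != new]
        return f"{graph_name}: recompiled due to dynamic shapes ({'; '.join(changes)})"


tracker = CompileCacheTracker()
for batch in (8, 8, 7, 8):  # e.g. a smaller last batch at the end of an epoch
    message = tracker.run("train_step", [(batch, 128), (128, 128)])
    if message:
        print(message)
```

A message like this points the user straight at the fix, for example padding or dropping the last partial batch.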
Better DDP support
Owner: TBD
PyTorch/XLA has been using its own collective communication (cc) ops (xm.all_reduce, xm.reduce_scatter) in most of our tutorials. We added DDP support last year, but there are a few open issues. The goal of this project is to improve DDP support and replace PyTorch/XLA's own cc ops for data parallelism.
Functionalization
Owner: Jiewen Tan, Wonjoo Lee
We adopted upstream functionalization (check out Brian's post) in place of PyTorch/XLA's own view implementation (for the history of that implementation, check here). This reduces the complexity of our code and fixes a longstanding issue where view relationships became invalid after a mark_step. For the 2.1 release, we provided a flag, XLA_DISABLE_FUNCTIONALIZATION, to make this configurable, but the long-term plan is to remove the flag.
Release 2.2 - 2.3 Feature Plan
Optimizer Tracing Time Regression
Owner: Wonjoo Lee
We noticed a tracing-time regression in the Adam optimizer. In some models this is especially problematic because the execution time is not long enough to hide the tracing time. Wonjoo Lee will investigate this regression and validate the fix on a set of torchbench models to make sure we do not regress performance.
Bounded Dynamic Shape
Owner: Xiongfei Wei
The scope of this project was to support bounded dynamic shape execution on TPU and GPU through XLA. This allows ops like nonzero to produce an output whose dynamic dimension has an upper bound. Xiongfei Wei made quite a bit of progress on reducing the recompilation count for a simple linear model with nonzero in the middle. However, due to the complexity of the project, team priorities, and resourcing issues, we have paused development on bounded dynamic shape.
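The idea behind bounded dynamic shapes can be sketched with nonzero: instead of returning a result whose length depends on the data, the op returns a buffer padded to a declared upper bound plus the true count, so every call has the same static shape and the same compiled program can be reused. bounded_nonzero is a hypothetical, pure-Python illustration; in XLA the padded region carries no meaning, and the zero padding here is an arbitrary choice.

```python
from typing import List, Sequence, Tuple


def bounded_nonzero(values: Sequence[float], upper_bound: int,
                    pad_value: int = 0) -> Tuple[List[int], int]:
    """Return nonzero indices padded to a static length, plus the true count.

    The output length is always upper_bound, so downstream code sees one static
    shape no matter how many elements are nonzero; only `count` is meaningful.
    """
    indices = [i for i, v in enumerate(values) if v != 0]
    if len(indices) > upper_bound:
        raise ValueError("more nonzero elements than the declared upper bound")
    padded = indices + [pad_value] * (upper_bound - len(indices))
    return padded, len(indices)


for data in ([0, 3, 0, 5], [1, 0, 0, 0]):
    padded, count = bounded_nonzero(data, upper_bound=len(data))
    print(padded, count)  # same output length for both inputs
```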
Release 2.2 - 2.3 Feature Plan
We will maintain all of the existing infrastructure and tests we developed for bounded dynamic shape, and we hope to resume the work in the future.
TPU Specific
Cloud TPU v5e support
Owner: Manfei Bai, Will Cromar, Jon Bolin, Jiewen Tan
A tremendous amount of effort has gone into letting Cloud TPU v5e deliver best-in-class Perf/TCO for inference, and we are working on bringing that to training as well.
Cloud TPU MultiSlice support
Owner: Jon Bolin, Mohit Khatwani, Jiewen Tan
Traditionally, TPUs are connected through ICI (Inter-Chip Interconnect) within a cluster called a pod. The latest LLMs demand so much computing power that not even a full TPU pod suffices. Hence, we recently developed technology that allows TPU slices to be connected through DCN (data center network) as well. This not only lets us harness the compute power of multiple TPU pods but also lets us interconnect TPU slices within a pod more flexibly.
Triton
Owner: TBD
Release 2.2 - 2.3 Feature Plan
PyTorch/XLA aims to provide an API for executing Triton GPU kernels on PyTorch/XLA:GPU through XLA custom calls. This will allow users to combine their Triton kernels with compiled XLA programs seamlessly.