PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework with backends that support XLA, including Cloud TPU, GPU, AWS Trainium, and AWS Inferentia. The project currently lives on GitHub.
I want to give a summary of the development work that happened across multiple projects during Q4 2022.
SPMD (Single Program Multiple Data)
- Features
- To-do’s
- Currently supports a single-host TPU backend only (see the usage sketch after this list). We will work on multi-host TPU enablement in 2023 H1.
- Started running single-host benchmarks with ResNet-50, Megatron-LM-style linear, and GPT-2 models. We will publish the test script and the preliminary results in the coming months.
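Below is a minimal usage sketch of how a sharding annotation might look with the experimental `xla_sharding` API. The module path `torch_xla.experimental.xla_sharding`, the `Mesh` constructor, and the `mark_sharding` signature are taken from the SPMD user guide and are assumptions about the exact Q4 2022 state; treat this as an illustration rather than the final API.

```python
# Hedged sketch (see assumptions above): annotate a tensor with a sharding
# spec over a 2D logical device mesh on a single TPU host.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

num_devices = 8  # assumed: a single v3-8/v4-8 host
device_ids = np.arange(num_devices)
mesh = Mesh(device_ids, mesh_shape=(2, 4))  # 2x4 logical mesh over 8 devices

t = torch.randn(16, 128, device=xm.xla_device())
# Shard tensor dim 0 over mesh axis 0 and tensor dim 1 over mesh axis 1.
xs.mark_sharding(t, mesh, partition_spec=(0, 1))
```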
TorchDynamo Integration
- The most recent update is TorchDynamo Update 10: Integrating with PyTorch/XLA for Inference and Training
- Inference
- The speedup is impressive; the TPU speed analysis can be found in TPU dynamo speed up on inference analysis · Issue #4328 · pytorch/xla · GitHub
- A simple inference example can be found in xla/test_dynamo.py at master · pytorch/xla · GitHub (see also the hedged sketch after this list)
- Training
- Training speed on TPU is under investigation. One cause of the slowdown is fragmented graphs.
- More details can be found in training support for dynamo+torchxla integration by shunting314 · Pull Request #88449 · pytorch/pytorch · GitHub
- TODO
- Enable fallback for Dynamo: [Feature Request][XLA] Support fallback for the dynamo-xla bridge · Issue #1823 · pytorch/torchdynamo · GitHub
- Enable optimizer graph capture during training
- Further speed analysis of Dynamo training
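For illustration, here is a hedged inference sketch in the spirit of the test linked above: compile a toy model through TorchDynamo with the PyTorch/XLA trace-once backend. The backend name `torchxla_trace_once` reflects the nightlies from this period and may change; the model and shapes are placeholders.

```python
# Hedged sketch: run inference on a toy model through the Dynamo + XLA bridge.
import torch
import torch._dynamo as dynamo
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 10).to(device).eval()

@dynamo.optimize("torchxla_trace_once")  # backend name assumed from this period
def infer(x):
    return model(x)

with torch.no_grad():
    out = infer(torch.randn(8, 128, device=device))
    print(out.shape)
```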
Dynamic Shapes
- Enabled alpha support for dynamic shapes on a multi-layer perceptron
- A forward-pass NN model with dynamic input has been validated (a hedged sketch follows this list).
- WIP: backward-pass support for NN models with dynamic input.
- Improved test coverage for dynamic shapes
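As a hedged illustration of the forward-pass scenario: `torch.nonzero` produces a tensor whose first dimension is data-dependent, and feeding it through a small MLP exercises a forward pass with dynamic input. The `XLA_EXPERIMENTAL` flag mentioned in the comment is an assumption about how the bounded dynamic-shape path was enabled at the time.

```python
# Hedged sketch: forward pass over a dynamically shaped input on XLA.
# May require e.g. XLA_EXPERIMENTAL="nonzero" in the environment (assumed).
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
mlp = torch.nn.Sequential(
    torch.nn.Linear(2, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1)
).to(device)

mask = torch.randint(0, 2, (32, 32), device=device)
idx = torch.nonzero(mask)       # first dimension is data-dependent
out = mlp(idx.float())          # forward pass with a dynamic input
xm.mark_step()
print(out.shape)
```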
FSDP
- Introduced an `auto_wrap_policy` feature in add `auto_wrap_policy` into XLA FSDP for automatic wrapping by ronghanghu · Pull Request #4318 · pytorch/xla · GitHub (a hedged sketch follows this list)
- Pending PR in HF to add native FSDP support: Enable PyTorch/XLA Fully Sharded Data Parallel (FSDP) for a Specific Class of Transformer Models by AlexWertheim · Pull Request #20774 · huggingface/transformers · GitHub
- TODO
- Add auto_wrap support in HF
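A hedged sketch of what auto-wrapping with the feature from PR #4318 might look like. The import path `torch_xla.distributed.fsdp.wrap` and the `size_based_auto_wrap_policy` helper are assumptions based on the XLA FSDP module mirroring the upstream wrapping utilities; the model and threshold are placeholders.

```python
# Hedged sketch: wrap large submodules automatically instead of by hand.
import functools
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp.wrap import size_based_auto_wrap_policy  # assumed path

device = xm.xla_device()
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.ReLU(), torch.nn.Linear(4096, 1024)
).to(device)

# Submodules above the parameter threshold become nested FSDP instances.
auto_wrap_policy = functools.partial(size_based_auto_wrap_policy,
                                     min_num_params=1_000_000)
fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy)
```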
PJRT
- Experimental support for TPUs with PJRT in 1.13 (a minimal runtime-selection sketch follows this list)
- Single GPU support in nightly (multi-GPU support coming soon)
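A minimal sketch of selecting the PJRT runtime. The `PJRT_DEVICE` environment variable is the documented switch for the experimental TPU support; using "GPU" with the nightly single-GPU support is assumed to follow the same pattern.

```python
# Hedged sketch: opt into the PJRT runtime before touching torch_xla.
import os
os.environ["PJRT_DEVICE"] = "TPU"  # "GPU" assumed for the nightly single-GPU path

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
print(torch.ones(2, 2, device=device) + 1)
```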
Modeling
- Improved ResNet training performance by 49% since Q2 2022.
- Achieved 300-epoch training convergence (0.85 Dice score) for UNet3D on the BraTS dataset with negligible convergence variance.
LTC (Lazy Tensor Core)
We are thrilled to announce that the Lazy Tensor Core (LTC) Migration | Final Phases effort is finally done. In this final phase, we achieved:
- Merged and reused 56% of the LazyTensor code in XLATensor, measured as the number of methods reused or overridden divided by the total number of methods in LazyTensor (total methods: 46; reused methods: 26). This number can increase to 83% once Functionalization adoption is done.
- Separated XLAGraphExecutor out of XLATensor and reused 61% of LazyGraphExecutor (same measurement as above; total methods: 85, reused methods: 52). Given the tight coupling between the actual backend and the graph executor, it is hard to increase this number further; however, most of the interfaces and concepts are now aligned.
- Made torchgen flexible with respect to lazy backends' shape classes so that xla::Shape can be used in the code-generated IR. This is critical for PyTorch/XLA's dynamic shapes development.
- Fully migrated to the TORCH_LAZY_* counters and metrics so that they are consistent across LTC and PyTorch/XLA (a hedged inspection sketch follows this list).
- Finished adopting applicable BackendInterfaces to embrace the LTC design principles.
- Experimented with adopting Functionalization and produced a PoC that validated the core concept: [EXTERNAL] Update on LTC Migration and Functionalization.
- Identified a restriction in LTCTensorImpl that makes it hard to adopt.
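As a hedged illustration of the shared counters mentioned above, the sketch below runs a tiny lazy workload and dumps the counter/metric report through `torch_xla.debug.metrics`; which specific TORCH_LAZY_* counters appear depends on what the workload actually triggers.

```python
# Hedged sketch: inspect counters/metrics after materializing a lazy graph.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(4, 4, device=device)
y = (x @ x).sum()
xm.mark_step()  # cut and execute the pending lazy graph

print(met.counter_names())   # includes counters defined via TORCH_LAZY_COUNTER
print(met.metrics_report())  # full human-readable counter/metric dump
```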
It has been a year since we started migrating to LTC, and it has been a long journey. During this time, PyTorch has evolved to PyTorch 2.0 and introduced torch.compile, which our team is actively exploring how to integrate with (TorchDynamo Update 10: Integrating with PyTorch/XLA for Inference and Training). However, LTC remains the core of PyTorch/XLA until we find a better way of lowering PyTorch models into XLA's HLOs. Therefore, we are committed to maintaining and further developing the LazyTensor technology together with the PyTorch team going forward. The completion of this LTC migration is the cornerstone of that commitment.
More on Functionalization
- We are able to run ResNet with fake data; here are the results on a v4-8 (img/s):
- With XRT:
Label | Mean | Median | 90th % | Std Dev | CV |
---|---|---|---|---|---|
Nightly | 1797.12 | 1795.47 | 1853.79 | 45.20 | 0.03 |
Functionalization branch | 1901.86 | 1904.11 | 1944.47 | 33.75 | 0.02 |
- With PJRT:
Label | Mean | Median | 90th % | Std Dev | CV |
---|---|---|---|---|---|
Nightly | 2342.51 | 2338.51 | 2402.58 | 41.01 | 0.02 |
Functionalization branch | 2376.11 | 2372.60 | 2414.03 | 28.00 | 0.01 |
- Details of the experiment can be found in: functionalization experiment
- Hardware: TPU v4-8
- Model: resnet50 with fake data (xla/test/test_train_mp_imagenet.py)
- Num epochs: 1
- We have fixed all upstream CI failures (25 failing test cases in total): Fixes for PyTorch/XLA functionalization integration by wonjoolee95 · Pull Request #88787 · pytorch/pytorch · GitHub.
- We have shrunk the total number of failing XLA test cases from 104 to 22 ([LTC] Functionalization tests).
- We have also marked the failing test cases as skipped in pull request #4158, and the CIs are now green.
- Wonjoo Lee has introduced a 'keep going' CI feature (inspired by upstream) that allows CI tests to continue running after failures so that developers can catch all failures at once. This is particularly useful for features like Functionalization that refactor the whole stack. (#4385)
Migrate to OpenXLA
- Uber helped to refactor OpenXLA
- TODO:
- Refine the design doc
- Migrate PyTorch/XLA from TensorFlow to OpenXLA via PRs, following the design doc