PyTorch/XLA 2022 Q4 Dev update

PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework with backends that support XLA, including Cloud TPU, GPU, AWS Trainium, and AWS Inferentia. The project currently lives on GitHub.

I want to give a summary of the development work that happened across multiple projects during Q4 2022.

SPMD (Single Program Multiple Data)

  • Features
    • Released XLAShardedTensor & sharding annotation API [RFC] in 1.13
    • Virtual device optimization (PR)
  • To-do’s
    • Currently supports a single-host TPU backend only. We will work on multi-host TPU enablement in 2023 H1.
    • Started running single-host benchmarks with ResNet50, Megatron-LM-style linear, and GPT-2 models. We will publish the test script and the preliminary results in the coming months.
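As background for the sharding annotation API, the core SPMD idea is that one program runs on every device, with each device holding a shard of the data. A toy stdlib-only sketch of that partitioning (`shard_2d` is a hypothetical helper for illustration, not the XLAShardedTensor API):

```python
# Toy illustration of SPMD-style sharding (hypothetical helper, not
# the torch_xla API): split a 4x4 matrix across a 2x2 logical device mesh.

def shard_2d(matrix, mesh_rows, mesh_cols):
    """Partition a dense 2D list into mesh_rows x mesh_cols shards."""
    n_rows, n_cols = len(matrix), len(matrix[0])
    rstep, cstep = n_rows // mesh_rows, n_cols // mesh_cols
    shards = {}
    for r in range(mesh_rows):
        for c in range(mesh_cols):
            shards[(r, c)] = [
                row[c * cstep:(c + 1) * cstep]
                for row in matrix[r * rstep:(r + 1) * rstep]
            ]
    return shards

matrix = [[i * 4 + j for j in range(4)] for i in range(4)]
shards = shard_2d(matrix, 2, 2)
# Each of the 4 "devices" now holds a 2x2 shard; e.g. the device at
# mesh position (0, 0) holds the top-left block [[0, 1], [4, 5]].
```

In the real API the annotation is attached to tensors and the XLA compiler handles the partitioning; the sketch only shows the data layout the mesh implies.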

TorchDynamo Integration

Dynamic Shapes

  • Enabled alpha support for dynamic shapes on multi-layer perceptrons
    • Forward-pass NN models with dynamic input have been validated.
    • WIP: backward-pass NN models with dynamic input support.
  • Improved test coverage for dynamic shapes



PJRT

  • Experimental support for TPUs with PJRT in 1.13
  • Single GPU support in nightly (multi-GPU support coming soon)
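With the 1.13 experimental support, the PJRT runtime is selected via an environment variable. A minimal invocation sketch (`train.py` is a placeholder for your own training script):

```shell
# Select the PJRT runtime and target TPU (script name is a placeholder).
PJRT_DEVICE=TPU python3 train.py
```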


GPU

  • Improved ResNet training performance by 49% since Q2 2022.
  • Achieved 300-epoch training convergence (0.85 DICE score) for UNet3D on the BraTS dataset, with negligible convergence variance.

LTC (Lazy Tensor Core)

Hi team, we are thrilled to announce that the Lazy Tensor Core (LTC) Migration | Final Phases effort ([EXTERNAL] Lazy Tensor Core (LTC) Migration | Final Phases) is finally done. In this final phase, we achieved:

  1. Merged and reused 56% of the LazyTensor code in XLATensor, measured as the number of methods reused or overridden divided by the total number of methods in LazyTensor (total methods: 46; reused methods: 26). This number can increase to 83% once Functionalization adoption is done.
  2. Separated XLAGraphExecutor out of XLATensor and reused 61% of LazyGraphExecutor (same measurement as above; total methods: 85, reused methods: 52). Given the tight coupling between the actual backend and the graph executor, it is hard to increase this number further; however, most of the interfaces and concepts are now aligned.
  3. Made torchgen flexible with lazy backends' shapes so that xla::Shape can be used in the code-generated IR. This is critical for PyTorch/XLA's Dynamic Shapes development.
  4. Fully migrated to the TORCH_LAZY_* counters and metrics so that they are consistent across LTC and PyTorch/XLA.
  5. Finished adopting applicable BackendInterfaces to embrace the LTC design principles.
  6. Experimented with adopting Functionalization and produced a PoC that validated the core concept: [EXTERNAL] Update on LTC Migration and Functionalization.
  7. Identified a restriction in LTCTensorImpl that makes it hard to adopt.
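The reuse ratios in items 1 and 2 follow directly from the stated method counts (truncated to a whole percent):

```python
# Reuse ratio = (methods reused or overridden) / (total methods),
# truncated to a whole percent, using the counts quoted above.

def reuse_pct(reused, total):
    return reused * 100 // total

xla_tensor_pct = reuse_pct(26, 46)      # LazyTensor methods reused in XLATensor
graph_executor_pct = reuse_pct(52, 85)  # LazyGraphExecutor methods reused
# xla_tensor_pct == 56 and graph_executor_pct == 61, matching the figures above.
```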

It has been a year since we started migrating to LTC, and it has been a long journey. During this time, PyTorch has evolved into PyTorch 2.0 and introduced torch.compile, which our team is actively exploring how to integrate with (TorchDynamo Update 10: Integrating with PyTorch/XLA for Inference and Training). However, LTC remains the core of PyTorch/XLA until we find a better way of lowering PyTorch models into XLA HLO. We are therefore committed to maintaining and further developing the LazyTensor technology together with the PyTorch team going forward. The completion of this LTC migration is the cornerstone of that commitment.

More on Functionalization

  • We are able to run ResNet with fake data; here are the results on v4-8 (img/s):
    • With XRT:
      Label                     Mean     Median   90th %   Std Dev  CV
      Nightly                   1797.12  1795.47  1853.79  45.20    0.03
      Functionalization branch  1901.86  1904.11  1944.47  33.75    0.02
    • With PJRT:
      Label                     Mean     Median   90th %   Std Dev  CV
      Nightly                   2342.51  2338.51  2402.58  41.01    0.02
      Functionalization branch  2376.11  2372.60  2414.03  28.00    0.01
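For reference, the CV column is the coefficient of variation (std dev / mean). A quick sketch reproducing it, plus the relative mean-throughput gains of the Functionalization branch over nightly (helper names here are illustrative):

```python
# Coefficient of variation (CV) = std dev / mean, and the relative
# mean-throughput gain of the Functionalization branch over nightly,
# computed from the table values above.

def cv(std_dev, mean):
    return round(std_dev / mean, 2)

def gain_pct(branch, nightly):
    return round((branch - nightly) / nightly * 100, 1)

xrt_cv_nightly = cv(45.20, 1797.12)     # 0.03, as in the XRT table
xrt_gain = gain_pct(1901.86, 1797.12)   # ~5.8% higher mean img/s under XRT
pjrt_gain = gain_pct(2376.11, 2342.51)  # ~1.4% higher mean img/s under PJRT
```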

Migrate to OpenXLA

  • Uber helped refactor OpenXLA
  • To-do's:
    • Refine the design doc
    • Migrate PyTorch/XLA from TensorFlow to OpenXLA according to the design doc, via a series of PRs