PyTorch/XLA 2.4 Dev Update

Hey, I am here to give a (late) update for the PyTorch/XLA 2.4 release. Similar to my previous updates, you can check our release notes for the detailed changes. Here I am going to highlight some of the new features and share how I think about them.

Eager Mode

In the 2.4 release we introduced an eager mode to PyTorch/XLA. You can enable it via

torch_xla.experimental.eager_mode(True)

This might sound odd if you are not familiar with torch_xla. The default mode of torch_xla is lazy tracing:

import torch
import torch_xla

a = torch.tensor(100, device=torch_xla.device())
# execution does not happen, we created a node called `COS` with an input
b = torch.cos(a)
# execution still does not happen, a `SIN` node is added to the graph
c = torch.sin(b)

# execution happens here, we construct a graph to calculate the value of both b and c
torch_xla.sync()

The motivation for this approach was to accumulate as many PyTorch ops as possible and execute them in a single graph to optimize performance. However, over the years we found this mode too confusing, so we decided to add an eager mode that is closer to the native PyTorch experience. Eager mode performance numbers can be found here, but the TL;DR is that it is very model dependent and ranges from 1% to 45% of the compiled performance.
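
With eager mode enabled, the same ops execute immediately instead of being accumulated into a graph. A minimal sketch of the contrast (the tensor values are just illustrative):

import torch
import torch_xla

torch_xla.experimental.eager_mode(True)

a = torch.tensor(100, device=torch_xla.device())
b = torch.cos(a)  # executes right away, no graph is accumulated
c = torch.sin(b)  # executes right away as well
# no torch_xla.sync() is needed to materialize b and c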

Going forward we want users to mix eager mode with a compile API, similar to native PyTorch. For inference it will look like

torch_xla.experimental.eager_mode(True)

compiled_model = torch.compile(model, backend="openxla")
res = compiled_model(input)

For training we currently recommend

torch_xla.experimental.eager_mode(True)

def step_fn(model, data, target, loss_fn, optimizer):
    optimizer.zero_grad()
    logits = model(data)
    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()
    return loss

step_fn = torch_xla.experimental.compile(step_fn)
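
Calling the compiled step function then looks like regular PyTorch training code. A minimal sketch, assuming a toy model, loss function, optimizer, and data loader that are not part of the original snippet:

import torch
import torch_xla

device = torch_xla.device()
model = torch.nn.Linear(16, 2).to(device)  # hypothetical toy model
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for data, target in train_loader:  # `train_loader` is assumed to exist
    data, target = data.to(device), target.to(device)
    loss = step_fn(model, data, target, loss_fn, optimizer)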

We hope to eventually converge on using torch.compile directly for training as well. To learn more about eager mode + compile, please take a look at this doc.

Pallas Kernel

In this release we worked to harden our Pallas kernel support. Pallas is a custom kernel language that supports both TPU and GPU. Pallas was born in JAX to support implementing fusions without going through the XLA compiler and to give better control over memory access patterns. It has become especially popular on the TPU side, and a couple of very popular custom kernels have been open sourced here.

PyTorch/XLA has made a lot of progress on porting existing Pallas kernels like flash_attention, paged_attention, gmm (megablocks for MoE), and tgmm into the PyTorch/XLA repo. Users can also call their own Pallas kernels, as we talked about in the last dev update. In order to use these kernels for training, we also did additional work to register the corresponding backward kernels and to add GSPMD support (GSPMD sharding propagation is disabled inside Pallas kernels since they are just black boxes to the compiler). You can take a look at this example for how to use our flash attention wrapper.
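
As a rough illustration, calling the flash attention wrapper might look like the sketch below; it assumes the wrapper is exposed as torch_xla.experimental.custom_kernel.flash_attention taking query/key/value tensors, so please check the linked example for the exact signature:

import torch
import torch_xla
from torch_xla.experimental.custom_kernel import flash_attention

device = torch_xla.device()
# (batch, num_heads, seq_len, head_dim) query/key/value tensors
q = torch.randn(1, 8, 1024, 128, device=device)
k = torch.randn(1, 8, 1024, 128, device=device)
v = torch.randn(1, 8, 1024, 128, device=device)

# dispatches to the Pallas flash attention kernel on TPU
o = flash_attention(q, k, v)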

The PyTorch/XLA team sees a rapidly growing need for custom kernels, and our goal is to make this experience easier. For more information please take a look at this doc.

Triton kernel

In this release we also added support for Triton kernels on XLA:GPU. You can use it via

import torch
import torch_xla
import torch_xla.experimental.triton as xla_triton

# lower the Triton kernel launch into an XLA custom-call payload
payload = xla_triton.triton_call(
    x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)

output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
                                              [output.shape], [torch.int64])
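
The snippet assumes x, y, output, grid, and add_kernel are already defined. For reference, the standard element-wise add kernel from the Triton tutorials is the kind of kernel that fits this call; a minimal sketch:

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)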

The Triton kernel will be treated as a black box and represented as a custom call in the generated HLO graph. For more detail you can take a look at this doc.

Usability & Debuggability Improvements

In the past few releases PyTorch/XLA has come a long way in optimizing performance, but the team is aware that usability still has many issues. On top of the eager mode change mentioned above, we also made a couple of changes to clean up our API and add more debugging tools. For example:

  1. added torch_xla.sync() to replace torch_xla.core.xla_model.mark_step()
  2. added torch_xla.device() to replace torch_xla.core.xla_model.xla_device()
  3. added support for torch_xla.core.xla_model.get_memory_info to query the TPU HBM info (see the sketch after this list)
  4. added compiled program HBM usage analysis when running with PT_XLA_DEBUG=1
  5. added PT_XLA_DEBUG_LEVEL=1 to only output debug messages on new compilations
  6. refreshed our doc page
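
A minimal sketch of the renamed APIs and the memory query (the tensor and the exact get_memory_info call pattern here are illustrative assumptions, not from the release notes):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch_xla.device()  # replaces xm.xla_device()
t = torch.randn(2, 2, device=device)
torch_xla.sync()             # replaces xm.mark_step()

# query the TPU HBM info
print(xm.get_memory_info(device))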

We plan to keep cleaning up our API and aligning it more closely with upstream PyTorch while improving the debugging experience.
