Float8 in PyTorch [1/x]

This note introduces the experimental support for float8 training in PyTorch. The float8 data formats (as laid out in https://arxiv.org/pdf/2209.05433.pdf) have been added in recent hardware such as NVIDIA Hopper. They offer a theoretical 2x improvement in both throughput and memory usage over 16-bit formats, which is why we want to make it easy to use these dtypes in PyTorch. Today, we want to share:

  1. Primitives to express a float8 matrix multiplication with per-tensor scaling:

     • the torch.float8_e4m3fn and torch.float8_e5m2 dtypes
     • the torch._scaled_mm op

  2. float8_experimental, a lightweight library for accelerating training with float8 in native PyTorch, with support for torch.compile and distributed. Initial results show throughput speedups of up to 1.2x on small-scale (8 GPU) LLaMa pretraining jobs. Peak memory usage improvements and large-scale distributed support are coming soon.

The rest of this note is laid out as follows:

  1. Float8 training building blocks, ELI5
  2. Float8 design choices
  3. State of float8 in PyTorch today
  4. Upcoming work

Float8 training building blocks, ELI5

This section will explain what’s going on under the hood of float8 training and how PyTorch is implementing the required building blocks.

The need for per-tensor scaling

Current SOTA float8 training recipes (https://arxiv.org/pdf/2209.05433.pdf, Section 4.3) require per-tensor scaling of weights and activations during training to achieve competitive accuracy.

Float8 matmul with per-tensor scaling

Here is pseudocode for calling a bf16 matmul:

# one liner!
C_bf16 = A_bf16 @ B_bf16

And here is pseudocode for calling a float8 matmul with per-tensor scaling:

# calculate per-tensor scales
A_s = get_scale(A_bf16, ...)
B_s = get_scale(B_bf16, ...)

# scale and cast to float8
A_f8 = (A_bf16 * A_s).to(torch.float8_e4m3fn)
B_f8 = (B_bf16 * B_s).to(torch.float8_e4m3fn)
# perform the matmul with scaled float8 inputs
C_bf16 = scaled_mm(A_f8, 1 / A_s, B_f8, 1 / B_s, ...)
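
For concreteness, here is one possible implementation of the get_scale helper used in the pseudocode above. This is a sketch only (get_scale is not a PyTorch or float8_experimental API, and the clamp epsilon is arbitrary): it maps the tensor's maximum magnitude to the largest value representable in the target float8 format.

import torch

# illustrative helper: per-tensor scale that maps max(abs(x)) to the
# largest representable value of the target float8 format
def get_scale(x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    amax = x.abs().max().to(torch.float32)
    fp8_max = torch.finfo(float8_dtype).max  # 448.0 for e4m3fn, 57344.0 for e5m2
    # the clamp avoids division by zero for all-zero tensors
    return fp8_max / torch.clamp(amax, min=1e-12)

With scales defined this way, passing 1 / A_s and 1 / B_s to the scaled matmul undoes the scaling, so the output comes back in the original range.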

We need a few things to make the latter work in PyTorch:

  1. Float8 primitives (dtypes + matmul) in core, for easy development out of core
  2. Per-tensor scaling, to achieve competitive accuracy
  3. Composability with systems such as autograd, torch.compile, inductor and distributed for competitive performance

Float8 primitives in core

We added the following primitives to PyTorch core to enable float8 modeling:

  • torch.float8_e4m3fn and torch.float8_e5m2 dtypes, matching the spec described in [2209.05433] FP8 Formats for Deep Learning.
  • torch._scaled_mm function, which wraps the cuBLAS float8 matmul routine and is about 2x faster than the bf16 mm on common LLaMa 70B shapes on an NVIDIA H100-SXM GPU.
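
As a quick illustration of these primitives (a minimal sketch; the tensor values are arbitrary and scaling is omitted for brevity):

import torch

print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0

x = torch.randn(16, 16, dtype=torch.bfloat16)
x_f8 = x.to(torch.float8_e4m3fn)  # in practice, scale first so values fit the range
print(x_f8.element_size())        # 1 byte per element, vs 2 for bfloat16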

A deep dive into per-tensor scaling

There are three common approaches to per-tensor scaling: dynamic, delayed and static. Dynamic and delayed scaling are supported by float8_experimental today.

Dynamic scaling

Dynamic scaling calculates the scale just-in-time based on the current tensor. Pseudocode:

amax = max(abs(x))
scale = amax_to_scale(amax, torch.float8_e4m3fn)
x_fp8 = (x * scale).to(torch.float8_e4m3fn)

The advantages of this approach are accuracy (the scale is always based on the values of the current tensor) and simplicity (it is stateless and can be implemented with __torch_dispatch__ overrides).

A drawback of this approach is performance: calculating the amax requires a reduction to a single element, and because the amax value is needed to compute the scale, it is difficult to fully fuse this reduction with the subsequent scaled cast.
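
Putting the pieces together, a minimal stateless dynamic-scaling cast could look like the following sketch (illustrative only, not the float8_experimental implementation):

import torch

def dynamic_cast_to_float8(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    # the amax reduction below is what makes full fusion hard: the scale
    # is not known until the reduction over the whole tensor finishes
    amax = x.abs().max().to(torch.float32)
    scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)
    x_f8 = (x * scale).to(float8_dtype)
    return x_f8, scale  # the scale is needed later by the scaled matmul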

Delayed scaling

Delayed scaling is a more fusion-friendly algorithm for per-tensor scaling. Pseudocode:

prev_amax = get_prev_amax(amax_history)
cur_scale = amax_to_scale(prev_amax, torch.float8_e4m3fn)
x_fp8 = (x * cur_scale).to(torch.float8_e4m3fn)
cur_amax = max(abs(x))
store_cur_amax_for_next_iteration(cur_amax)

Note that calculating cur_amax no longer blocks the scaled cast, so the two can be fused or overlapped, leading to a potential performance improvement over dynamic scaling.

A drawback of delayed scaling is implementation complexity. We need to track persistent state for each float8 cast op in the computational graph, which requires a model rewrite to implement cleanly. In this way, float8 training with delayed scaling is more akin to quantization-aware training than to lighter-weight features such as model.to(bfloat16) or automatic mixed precision.
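
To make the statefulness concrete, here is a minimal sketch of a delayed-scaling cast as a module with a persistent amax buffer. The buffer name, history length, and max-over-history policy are assumptions for illustration, not the float8_experimental implementation:

import torch
import torch.nn as nn

class DelayedScalingCast(nn.Module):
    def __init__(self, history_len: int = 16, float8_dtype=torch.float8_e4m3fn):
        super().__init__()
        self.float8_dtype = float8_dtype
        # persistent per-cast state: a rolling window of observed amax values
        # (a real implementation would also handle the first iterations specially)
        self.register_buffer("amax_history", torch.zeros(history_len))

    def forward(self, x: torch.Tensor):
        # the scale comes from previous iterations, so computing it does not
        # require a reduction over the current tensor
        prev_amax = self.amax_history.max().clamp(min=1e-12)
        scale = torch.finfo(self.float8_dtype).max / prev_amax
        x_f8 = (x * scale).to(self.float8_dtype)
        # record the current amax for future iterations; this reduction can
        # be fused with or overlapped with the scaled cast above
        with torch.no_grad():
            self.amax_history.copy_(torch.roll(self.amax_history, 1))
            self.amax_history[0] = x.abs().max()
        return x_f8, scale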

Static scaling

Static scaling pre-calculates the scales of each tensor offline. Since this requires the weights to be frozen for good accuracy, this is usually an inference-only technique. In pseudocode:

scale = get_precalculated_scale(...)
x_fp8 = (x * scale).to(torch.float8_e4m3fn)

The advantage is highest performance (the overhead of amax calculation is completely removed), and the drawbacks are statefulness (needs a model rewrite) and reduced accuracy (because all the scales are precomputed).
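
For example, a static scale could be calibrated offline from representative data. A hypothetical sketch (calibrate_static_scale is not an existing API):

import torch

def calibrate_static_scale(calibration_batches, float8_dtype=torch.float8_e4m3fn):
    # record the worst-case amax over representative inputs and freeze the
    # resulting scale for use at inference time
    amax = max(x.abs().max().item() for x in calibration_batches)
    return torch.finfo(float8_dtype).max / max(amax, 1e-12)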

Composability with key PyTorch systems

This section details how float8 training interacts with key PyTorch systems, and the work done to enable this.

autograd

In the context of float8 training, for a tensor x we usually need x.dtype to be float8 but x.grad.dtype to be bfloat16. Autograd currently requires x.grad.dtype to equal x.dtype, for historical reasons. To get around this restriction we use Float8Tensor, which stores its raw data in float8 but advertises its dtype to autograd as bfloat16.
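
Here is a highly simplified sketch of that trick: a wrapper tensor subclass that carries float8 data plus a scale but advertises the original high-precision dtype. This is not the real Float8Tensor (it omits __torch_dispatch__ handling and the autograd integration), and _make_wrapper_subclass is a private helper:

import torch

class SimpleFloat8Tensor(torch.Tensor):
    # advertises orig_dtype (e.g. bfloat16) as its dtype so that autograd's
    # x.dtype == x.grad.dtype check is satisfied, while the payload it carries
    # is the raw float8 data plus the scale used to produce it
    @staticmethod
    def __new__(cls, data_f8, scale, orig_dtype=torch.bfloat16):
        return torch.Tensor._make_wrapper_subclass(
            cls, data_f8.shape, dtype=orig_dtype, device=data_f8.device
        )

    def __init__(self, data_f8, scale, orig_dtype=torch.bfloat16):
        self._data = data_f8
        self._scale = scale

    def to_original_precision(self):
        # dequantize: undo the scaling that was applied before the float8 cast
        return self._data.to(self.dtype) / self._scale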

torch.compile

Scaling and casting tensors to float8 introduces overhead; we accept this overhead in eager mode to keep the code simple, and depend on torch.compile + inductor to recover performance. For example, LLaMa 7B training with float8 dynamic scaling has a speedup of 0.81x over bf16 in eager mode and 1.22x with torch.compile. Please see the performance section later in this document for detailed speedups.

Float8 code is now fully torch.compile compatible, both for single-GPU training and for FSDP with 16-bit all-gather, thanks to new torch.compile support for traceable tensor subclasses, tensor subclasses combined with torch.autograd.Function, and in-place buffer mutations in the backward pass.

inductor

After torch.compile captures a graph, we use inductor to generate kernels in which the amax calculation and scaled cast are fused into surrounding ops. We added inductor support for the float8 dtypes and optimized code generation for the amax calculation, scaling, and float8 casts needed by Float8Linear.
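
One way to see this in action is to compile just the dynamic scaled cast and inspect the code inductor generates. A sketch (the logging knob assumes a recent PyTorch build and may differ across versions; shapes are arbitrary):

import torch

torch._logging.set_logs(output_code=True)  # or set TORCH_LOGS="output_code"

@torch.compile
def dynamic_scaled_cast(x: torch.Tensor):
    # amax reduction, scaling and float8 cast; candidates for fusion by inductor
    amax = x.abs().max()
    scale = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=1e-12)
    return (x * scale).to(torch.float8_e4m3fn), scale

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
x_f8, scale = dynamic_scaled_cast(x)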

distributed

Given a float8 matmul C_bf16 = A_fp8 @ B_fp8 and a distributed paradigm that communicates A or B across ranks, there are opportunities to do the communication in float8, reducing communication latency by up to 50% compared to bf16. Below is an overview of how this applies to common distributed paradigms. Note that today, float8_experimental only composes with 16-bit all-gather FSDP. FSDP with float8 all-gather and DTensor composability are planned for 2024.

FSDP

If the gemm compute uses float8, then all-gathering the matmul weights in float8 is a free communication latency win. However, this is not possible to implement cleanly with today's FSDP design. We are working on per-parameter FSDP in 2024, which will support float8 all-gather.

Float8 reductions are currently an open area of research and we are following the developments.

TensorParallel / SequenceParallel

In TP/SP (example from MegatronLM, https://arxiv.org/pdf/2205.05198.pdf, Figure 5), if the linear compute uses float8 then doing the all-gathers in both g and g_bar is a free latency win, and reduce-scatters are out of scope for the same reason as in the FSDP case (float8 reductions are still an open research area). float8_experimental currently supports float8 all-gather with Fairscale's RowParallelLinear and ColumnParallelLinear TP/SP implementations, but this is not compatible with torch.compile. Our plan for enabling torch.compile support is to compose with DTensor (below).

DTensor

DTensor is a new way to write distributed modeling code in PyTorch without requiring model code changes. We plan to enable composing float8 with DTensor to support TP/SP + float8 + torch.compile.

Float8 design choices

Float8 dtypes are unscaled

torch.float8_e4m3fn and torch.float8_e5m2 are unscaled, with the expectation that the higher level modeling code will track the scales. Ops such as torch._scaled_mm which require scales take them as separate arguments. This allows for flexibility when implementing scaling in the modeling layer, such as extending scales from per-tensor to a finer granularity. We use tensor subclasses such as Float8Tensor for easy tracking of raw_data + scale in the modeling layer.

Use model rewrites to implement per-tensor scaling

The current SOTA scaling strategy is delayed scaling; this requires stateful per-tensor statistics collection for a subset of weights, activations, and gradients. A model rewrite is necessary to implement this cleanly; lighter-weight approaches such as automatic mixed precision are not expressive enough. Even for stateless scaling strategies such as dynamic scaling, a model rewrite implementation allows them to be easily compared with stateful strategies.

The current model rewrite approach we are using in float8_experimental is module swaps. In the future, we may explore module hooks and graph capture + graph pass to cover more cases.
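
For illustration, a generic module swap can be as simple as the following sketch. The float8_linear_cls.from_float constructor is hypothetical; see the float8_experimental README for the actual entry points:

import torch.nn as nn

def swap_linears(module: nn.Module, float8_linear_cls) -> nn.Module:
    # recursively replace nn.Linear submodules with a float8-aware variant
    # that owns its scaling state (amax history, scales, ...)
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, float8_linear_cls.from_float(child))
        else:
            swap_linears(child, float8_linear_cls)
    return module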

Eager debuggability, compiled performance

Per-tensor scaling adds overhead via the calculation of max(abs(tensor)) and the dtype casts. We keep our code simple in eager mode for easy debuggability and to minimize the need to write custom kernels. We depend on torch.compile + inductor to recover performance.

Distributed

We plan to compose with DTensor for our long-term distributed support. Note that this functionality is not fully implemented yet.

State of float8 in PyTorch, as of 2024-01-16

  • float8 primitives (dtypes, matmul) are landed to core

  • autograd, torch.compile and inductor support float8 scaling, casting and gemms

  • we open sourced float8_experimental, a library for accelerating training with float8 in native PyTorch. Highlights:

    • an easily hackable/debuggable codebase written in 4.5k LOC of Python
    • supports delayed and dynamic scaling
    • supports torch.compile
    • supports FSDP with 16-bit all-gather
    • FSDP with float8 all-gather and TP/SP support are planned
    • accuracy: no significant degradation from bf16 training on LLaMa 7B and LLaMa 70B
    • LLaMa 7B / 1 GPU pretraining: 1.22x speedup with a 1.07x peak memory increase over bf16. LLaMa 13B / 8 GPUs: 1.20x speedup with a 1.22x peak memory increase over bf16. The memory overhead is lower for a single GPU (~1.1x) than for 8 GPUs (1.22x) because our current FSDP integration multiplies this overhead (we plan to address this).

float8_experimental performance on small scale LLaMa

Key things to note:

  1. The e2e speedup is bounded by the percentage of time spent in GEMMs.
  2. In eager mode, PyTorch’s float8 is slower than bf16. This is expected as we explicitly require torch.compile for performance.
  3. PyTorch float8’s dynamic and delayed scaling are close in performance with the current implementation. Performance work in inductor to speed up both dynamic and delayed scaling is planned.

float8_experimental memory usage on small scale LLaMa

Our current implementation of float8 training is naive with respect to memory: it increases memory usage because we save extra float8 copies of activations and weights for the backward pass. We plan to improve this by allowing configurable recomputation of the float8 casts in the backward.

Key things to note:

  1. PyTorch float8 + single GPU: we observe an increase of up to 1.11x in memory usage due to storing extra float8 weights and activations for the backward. We plan to add an option to let the user choose to recompute them instead.
  2. PyTorch float8 + multi-GPU FSDP: we observe a significant memory regression which scales with the number of GPUs. This happens because we save the unsharded float8 version of the weight for the backward. We plan to fix this in the same manner as (1) and ensure this feature composes with FSDP’s unshard->free logic.
  3. Our TP/SP support does not work with torch.compile yet, which is why we stopped at LLaMa 13B for this note.

Upcoming work

Two major improvements are in the works:

  1. improved composability with distributed. Specifically, we plan to reduce peak memory usage when using float8 with FSDP, enable FSDP with float8 all-gather, and compose with DTensor to enable TP/SP.
  2. native support for float8 inference

Stay tuned for more updates, and please feel free to file issues in the pytorch-labs/float8_experimental GitHub repository.
