This note introduces the experimental support for float8 training in PyTorch. The float8 data formats (as laid out in https://arxiv.org/pdf/2209.05433.pdf) have been added to recent hardware such as NVIDIA Hopper. They have a theoretical 2x advantage in throughput and memory usage over 16-bit formats, which is why we want to make it easy to use these dtypes in PyTorch. Today, we want to share:

- Primitives to express a float8 matrix multiplication with per-tensor scaling:
  - the torch.float8_e4m3fn and torch.float8_e5m2 dtypes
  - the torch._scaled_mm op
- float8_experimental, a lightweight library for accelerating training with float8 in native PyTorch, with support for torch.compile and distributed. Initial results show throughput speedups of up to 1.2x on small-scale (8 GPUs) LLaMa pretraining jobs. Peak memory usage improvements and large-scale distributed support are coming soon.
The rest of this note is laid out as follows:
- Float8 training building blocks, ELI5
- Float8 design choices
- State of float8 in PyTorch today
- Upcoming work
Float8 training building blocks, ELI5
This section explains what happens under the hood of float8 training and how PyTorch implements the required building blocks.
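The need for per-tensor scaling
Current SOTA float8 training recipes (https://arxiv.org/pdf/2209.05433.pdf, Section 4.3) require per-tensor scaling of weights and activations during training to achieve competitive accuracy. Float8 formats have a much narrower dynamic range than bf16 (for example, the largest finite torch.float8_e4m3fn value is 448), so each tensor is multiplied by a scale that brings its values into the representable range before the cast.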
Float8 matmul with per-tensor scaling
Here is pseudocode for calling a bf16 matmul:
# one liner!
C_bf16 = A_bf16 @ B_bf16
And here is pseudocode for calling a float8 matmul with per-tensor scaling:
# calculate per-tensor scales
A_s = get_scale(A_bf16, ...)
B_s = get_scale(B_bf16, ...)
# scale and cast to float8
A_f8 = (A_bf16 * A_s).to(torch.float8_e4m3fn)
B_f8 = (B_bf16 * B_s).to(torch.float8_e4m3fn)
# perform the matmul with scaled float8 inputs
C_bf16 = scaled_mm(A_f8, 1 / A_s, B_f8, 1 / B_s, ...)
We need three things to make the float8 version work in PyTorch:
- Float8 primitives (dtypes + matmul) in core, for easy development out of core
- Per-tensor scaling, to achieve competitive accuracy
- Composability with systems such as autograd, torch.compile, inductor and distributed, for competitive performance
Float8 primitives in core
We added the following primitives to PyTorch core to enable float8 modeling:
- torch.float8_e4m3fn and torch.float8_e5m2 dtypes, matching the spec described in [2209.05433] FP8 Formats for Deep Learning.
- torch._scaled_mm function, which wraps the cuBLAS float8 matmul routine and is about 2x faster than the bf16 mm on common LLaMa 70B shapes on an NVIDIA H100 SXM GPU.
A deep dive into per-tensor scaling
There are three common approaches to per-tensor scaling: dynamic, delayed and static. Dynamic and delayed scaling are supported by float8_experimental today.
Dynamic scaling
Dynamic scaling calculates the scale just-in-time based on the current tensor. Pseudocode:
amax = max(abs(x))
scale = amax_to_scale(amax, torch.float8_e4m3fn)
x_fp8 = (x * scale).to(torch.float8_e4m3fn)
An advantage of this approach is accuracy (the scale is always based on the values of the current tensor) and simplicity (it is stateless and can be implemented with __torch_dispatch__ overrides).
A drawback of this approach is performance. Calculating the amax requires a reduction to a single element, and because the scaled cast cannot start until the amax (and therefore the scale) is known, it is difficult to fully fuse the reduction with the subsequent scaled cast.
Delayed scaling
Delayed scaling is a more fusion-friendly algorithm for per-tensor scaling. Pseudocode:
prev_amax = get_prev_amax(amax_history)
cur_scale = amax_to_scale(prev_amax, torch.float8_e4m3fn)
x_fp8 = (x * cur_scale).to(torch.float8_e4m3fn)
cur_amax = max(abs(x))
store_cur_amax_for_next_iteration(cur_amax)
Note that calculating cur_amax can now overlap with the scaled cast, because the scale comes from previous iterations rather than from the current tensor, leading to a potential performance improvement over dynamic scaling.
A drawback of delayed scaling is implementation complexity. We need to track persistent state for each float8 cast op in the computational graph, which requires a model rewrite to implement cleanly. In this way, float8 training with delayed scaling is more akin to quantization-aware training than to lighter-weight features such as model.to(torch.bfloat16) or automatic mixed precision.
Static scaling
Static scaling precalculates the scale of each tensor offline. Since this requires the weights to be frozen for good accuracy, it is usually an inference-only technique. In pseudocode:
scale = get_precalculated_scale(...)
x_fp8 = (x * scale).to(torch.float8_e4m3fn)
The advantage is highest performance (the overhead of amax calculation is completely removed), and the drawbacks are statefulness (needs a model rewrite) and reduced accuracy (because all the scales are precomputed).
Composability with key PyTorch systems
This section details how float8 training interacts with key PyTorch systems, and the work done to enable this.
autograd
In the context of float8 training, for a tensor x we usually need x.dtype to be float8 but x.grad.dtype to be bfloat16. Autograd currently enforces x.dtype to equal x.grad.dtype for historical reasons. To get around this restriction we use Float8Tensor, which stores the raw data in float8 but advertises its dtype to autograd as bfloat16.
torch.compile
Scaling and casting tensors to float8 introduces overhead; we accept this overhead in eager mode to keep the code simple, and depend on torch.compile + inductor to recover performance. For example, LLaMa 7B training with float8 dynamic scaling has a speedup of 0.81x over bf16 in eager mode, and 1.22x with torch.compile. Please see the performance section later in this document for detailed speedups.
Float8 training is now fully torch.compile compatible on a single GPU as well as with FSDP using 16-bit all-gather, thanks to new torch.compile support for traceable tensor subclasses, tensor subclasses + torch.autograd.Function, and in-place buffer mutations in the backward.
inductor
After we get a graph from torch.compile, we use inductor to generate kernels in which the amax calculation and the scaled cast are fused into surrounding ops. We added inductor support for float8 dtypes, and optimized code generation for the amax calculation, scaling and float8 casts that Float8Linear needs.
distributed
Given a float8 matmul C_bf16 = A_fp8 @ B_fp8 and a distributed paradigm that communicates either A or B across ranks, there are opportunities to do the communication in float8, reducing communication latency by up to 50% compared to bf16 (float8 elements are half the size of bf16 elements, so 50% is the upper bound). Below is an overview of how this applies to common distributed paradigms. Note that today, float8_experimental only composes with FSDP using 16-bit all-gather. FSDP with float8 all-gather and DTensor composability are planned for 2024.
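The upper bound follows from simple byte counting. The sketch below uses a hypothetical 4096 x 4096 weight and assumes one float32 scale travels with the float8 payload; allgather_bytes is an illustrative helper, not a real API. Fixed per-message costs do not shrink with payload size, so real latency savings land below this bound.

```python
def allgather_bytes(num_elements: int, bytes_per_element: int, num_scales: int = 0) -> int:
    # Assumption: each per-tensor scale is sent as a 4-byte float32
    return num_elements * bytes_per_element + num_scales * 4


n = 4096 * 4096  # hypothetical weight matrix
bf16_bytes = allgather_bytes(n, 2)
fp8_bytes = allgather_bytes(n, 1, num_scales=1)
saving = 1 - fp8_bytes / bf16_bytes  # just under 0.5 because of the scale
print(bf16_bytes, fp8_bytes, f"{saving:.6f}")
```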
FSDP
If the GEMM compute is using float8, then doing the all-gather for matmul weights in float8 is a free communication latency win. However, this is not possible to implement cleanly with today's FSDP design. We are working on per-parameter FSDP in 2024, which will support float8 all-gather.
Float8 reductions are currently an open area of research and we are following the developments.
TensorParallel / SequenceParallel
In TP/SP (example from Megatron-LM, https://arxiv.org/pdf/2205.05198.pdf, Figure 5), if the linear compute is using float8, then doing the all_gathers in both g and g_bar in float8 is a free latency win, while reduce-scatters are out of scope for the same reason as in the FSDP case (float8 reductions are an open research area). float8_experimental currently supports float8 all-gather over FairScale's RowParallelLinear and ColumnParallelLinear TP/SP implementation, but this is not compatible with torch.compile. Our plan for enabling torch.compile support is composing with DTensor (below).
DTensor
DTensor is a new way to write distributed modeling code in PyTorch without requiring model code changes. We plan to enable composing float8 with DTensor to support TP/SP + float8 + torch.compile.
Float8 design choices
Float8 dtypes are unscaled
torch.float8_e4m3fn and torch.float8_e5m2 are unscaled, with the expectation that higher-level modeling code will track the scales. Ops such as torch._scaled_mm that require scales take them as separate arguments. This allows for flexibility when implementing scaling in the modeling layer, such as extending scales from per-tensor to a finer granularity. We use tensor subclasses such as Float8Tensor for easy tracking of raw_data + scale in the modeling layer.
Use model rewrites to implement per-tensor scaling
The current SOTA scaling strategy is delayed scaling; this requires stateful per-tensor statistics collection for a subset of weights, activations, and gradients. A model rewrite is necessary to implement this cleanly; lighter-weight approaches such as automatic mixed precision are not expressive enough. Even for stateless scaling strategies such as dynamic scaling, a model rewrite implementation allows them to be easily compared with stateful strategies.
The current model rewrite approach we are using in float8_experimental is module swaps. In the future, we may explore module hooks and graph capture + graph pass to cover more cases.
Eager debuggability, compiled performance
Per-tensor scaling adds overhead via the calculation of max(abs(tensor)) and dtype casts. We keep our code simple in eager mode for easy debuggability and to minimize the need to write kernels. We depend on torch.compile + inductor to recover performance.
Distributed
We plan to compose with DTensor for our long term distributed support. Note that this functionality is not fully implemented yet.
State of float8 in PyTorch, as of 2024-01-16

- float8 primitives (dtypes, matmul) have landed in core
- autograd, torch.compile and inductor support float8 scaling, casting and GEMMs
- we open sourced float8_experimental: a library for accelerating training with float8 in native PyTorch. Highlights:
  - an easily hackable/debuggable codebase written in 4.5k LOC of Python
  - supports delayed and dynamic scaling
  - supports torch.compile
  - supports FSDP with 16-bit all-gather
  - FSDP with float8 all-gather and TP/SP support are planned
  - accuracy: no significant degradation from bf16 training on LLaMa 7B and LLaMa 70B
  - LLaMa 7B / 1 GPU pretraining: 1.22x speedup with a 1.07x peak memory increase over bf16. LLaMa 13B / 8 GPUs: 1.20x speedup with a 1.22x peak memory increase over bf16. The memory overhead is lower on a single GPU (1.1x) than on 8 GPUs (1.22x) because our current FSDP integration multiplies this overhead (we plan to address this).
float8_experimental performance on small-scale LLaMa
Key things to note:
- The e2e speedup is bounded by the percentage of time spent in GEMMs.
- In eager mode, PyTorch's float8 is slower than bf16. This is expected, as we explicitly require torch.compile for performance.
- PyTorch float8's dynamic and delayed scaling are close in performance with the current implementation. Performance work in inductor to speed up both dynamic and delayed scaling is planned.
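The first point is Amdahl's law. The sketch below uses hypothetical numbers (they are not measurements from this note): only the GEMM share of step time gets faster, and an optional extra term models the amax, scaling and cast overhead that float8 adds.

```python
def e2e_speedup(gemm_fraction: float, gemm_speedup: float, overhead_fraction: float = 0.0) -> float:
    """Amdahl's law: only the GEMM share of step time gets faster.
    overhead_fraction (an assumption) models extra amax/scale/cast time."""
    new_time = (1 - gemm_fraction) + gemm_fraction / gemm_speedup + overhead_fraction
    return 1 / new_time


# Hypothetical: 60% of step time in GEMMs, GEMMs run 2x faster in float8
print(e2e_speedup(0.6, 2.0))  # about 1.43x
print(e2e_speedup(0.6, 2.0, overhead_fraction=0.1))  # 1.25x once casts cost 10% of the step
```

Even an infinitely fast GEMM caps the speedup at 1 / (1 - gemm_fraction), and unfused scaling overhead eats directly into what is left, which is why eager mode can end up slower than bf16.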
float8_experimental memory usage on small-scale LLaMa
Our current, naive implementation of float8 training increases memory usage because it saves extra float8 copies of activations and weights for the backward pass. We plan to improve this by allowing configurable recomputation of float8 casts in the backward.
Key things to note:
- PyTorch float8 + single GPU: we observe an up to 1.11x increase in memory usage due to storing extra float8 weights and activations for the backward. We plan to add an option to let the user recompute them instead.
- PyTorch float8 + multi-GPU FSDP: we observe a significant memory regression which scales with the number of GPUs. This happens because we save the unsharded float8 version of each weight for the backward. We plan to fix this with the same recompute option as in the single-GPU case, and to ensure this feature composes with FSDP's unshard-then-free logic.
- Our TP/SP support does not work with torch.compile yet, which is why we stopped at LLaMa 13B for this note.
Upcoming work
Two major improvements are in the works:
- improved composability with distributed. Specifically, we plan to reduce peak memory usage when using float8 with FSDP, enable FSDP with float8 all-gather, and compose with DTensor to enable TP/SP.
- native support for float8 inference
Stay tuned for more updates, and please feel free to file issues on GitHub at pytorch-labs/float8_experimental, the repository for the experimental PyTorch native float8 training UX.