Hi PyTorch Community!
This post is supplementary material for our soon-to-be-published “What Every User Should Know About Mixed Precision Training in PyTorch” blog post. We hope it helps you use mixed precision in PyTorch even more!
More In-Depth Details of Floating Point Precision
Floating-point (FP) formats consist of a sign bit, exponent bits, and mantissa bits. Using these bits, a real number is represented as: (-1)^sign × mantissa × 2^exponent.
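As a small illustration of this formula (a quick sketch, not part of the figures below), Python's standard `math.frexp` decomposes a float into a mantissa and an exponent, so we can check the representation directly:

```python
import math

def decompose(x: float):
    # math.frexp returns (m, e) with 0.5 <= |m| < 1 and x == m * 2**e
    m, e = math.frexp(x)
    sign = 0 if math.copysign(1.0, x) > 0 else 1
    return sign, abs(m), e

sign, mantissa, exponent = decompose(-6.25)
print(sign, mantissa, exponent)                   # 1 0.78125 3
print((-1) ** sign * mantissa * 2 ** exponent)    # -6.25, matching the formula above
```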
Figure 1 summarizes the bit partitioning of the most commonly used floating-point formats in PyTorch. The exponent and mantissa bits dictate how accurately we can represent a real number in the given format. Mixed precision computation means using an FP type with a smaller bit representation and less accuracy than the typical or reference type (detailed information available in Cherubin et al.).
Figure 1: Bit Partitioning of Floating-Point Formats. TF32 Bringing the Best of FP32 and FP16. (Source)
The range bits (the exponent) determine how large or small a number the format can express. The precision bits (the mantissa) determine how closely a number in the format can approximate its true value. For an accessible introduction to these axes, refer to this excellent article written by Bartosz Ciechanowski.
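A quick way to inspect both axes for the formats in Figure 1 is `torch.finfo`: `max` and `tiny` describe the range, while `eps` describes the precision. (TF32 is a compute mode rather than a storage dtype, so it does not appear here.)

```python
import torch

# Range (max, smallest normal) and precision (eps) for common floating-point dtypes.
for dtype in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} max={info.max:.3e}  smallest normal={info.tiny:.3e}  eps={info.eps:.3e}")
```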
All floating point operations have intrinsic error. There is no one-size-fits-all format for every application, and depending on your application, multiple formats may provide the numerical properties required. Specifically for Deep Learning, as mentioned above, a plethora of use cases train faster with mixed precision without any impact on model accuracy. Figure 2 shows that training BERT with two mixed precision options (FP16 mixed precision and TF32) yields consistent network convergence and near identical accuracies.
Figure 2: Error-Tolerant Behavior of Deep Learning Training (Source).
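For reference, both options from Figure 2 are enabled with a few lines in PyTorch (assuming a CUDA device; the gradient-scaling half of FP16 training is sketched further below):

```python
import torch

# TF32 mode: allow FP32 matmuls and cuDNN convolutions to use TF32 tensor cores.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# FP16 mixed precision: run the forward pass under autocast.
model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(64, 1024, device="cuda")
with torch.cuda.amp.autocast():
    y = model(x)          # the matmul runs in FP16 here
print(y.dtype)            # torch.float16
```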
On the other hand, it is possible to evaluate the numerical properties that matter for many non-DL workloads. Figure 3 compares the closeness of torch.matmul computed in different mixed precision formats (with torch.randn inputs). We can see that the accuracy of TF32 for matmul can depend on the availability of a kernel (selected by cuBLAS) for a specific k dimension. cuBLAS has many TF32-enabled kernels; however, it occasionally cannot find one for a specific input size and selects an IEEE FP32 kernel instead. This can clearly be seen when we plot the absolute max difference of each kernel's output from the same inputs processed in double precision.
Figure 3: Error-Prone Behavior of torch.matmul. The lines compute the absolute max difference of torch.matmul computed in a reduced precision format — BF16 (green), FP16 (blue), TF32 (red), FP32 (yellow) — from its value in a reference format (FP64), signifying the closeness of the values in the same computation.
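A minimal sketch of the Figure 3 style comparison (assuming a CUDA device; the exact numbers depend on your GPU and cuBLAS version, and the matrix sizes here are placeholders):

```python
import torch

# Compare torch.matmul in reduced precision against an FP64 reference
# and report the absolute max difference, as in Figure 3.
def max_abs_diff(k, m=1024, n=1024, device="cuda"):
    a = torch.randn(m, k, device=device, dtype=torch.float64)
    b = torch.randn(k, n, device=device, dtype=torch.float64)
    ref = a @ b  # FP64 reference

    results = {}
    for name, dtype in (("bf16", torch.bfloat16), ("fp16", torch.float16)):
        out = (a.to(dtype) @ b.to(dtype)).double()
        results[name] = (out - ref).abs().max().item()

    for name, allow_tf32 in (("tf32", True), ("fp32", False)):
        torch.backends.cuda.matmul.allow_tf32 = allow_tf32
        out = (a.float() @ b.float()).double()
        results[name] = (out - ref).abs().max().item()
    return results

print(max_abs_diff(k=4096))
```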
This analysis assumes the inputs follow a standard normal distribution, which many applications' inputs do not closely follow. It's therefore important to keep this in mind when reasoning about the precision of different formats, as different input distributions will result in different accuracies. However, for a given input distribution, the error is expected to correlate strongly with the dot product dimension of the matrix multiplies (similar parallels can be drawn to convolutions, where the error scales with the product of the channels and window dimensions of the convolution).
Intrinsic Errors in Floating Point Computations
Nathan et al. provide a good summary of the errors that are intrinsic to floating point computation. These are round-off, differences in addend exponents, cancellation, and near overflow/underflow errors. For a demonstration of these errors, refer to this gist and notebook. For a more thorough treatment of this topic, refer to Higham. Knowledge of these errors is helpful when deciding on the best floating point precision mode for a use case, especially outside of Deep Learning.
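The linked gist and notebook walk through all of these; as a small taste, a few of them are easy to trigger in FP16:

```python
import torch

# Absorption / round-off: near 2048 the FP16 spacing is 2, so adding 1 has no effect.
x = torch.tensor(2048.0, dtype=torch.float16)
print((x + 1.0).item())             # 2048.0 -- the 1.0 is absorbed

# Cancellation: each step rounds correctly, but subtracting nearly equal values
# exposes the earlier round-off as a large relative error.
one = torch.tensor(1.0, dtype=torch.float16)
eps = torch.tensor(1e-3, dtype=torch.float16)
print(((one + eps) - one).item())   # 0.0009765625 instead of 0.001 (~2.3% relative error)

# Near overflow: FP16's max is ~65504, so modest values already overflow when squared.
print((torch.tensor(300.0, dtype=torch.float16) ** 2).item())   # inf
```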
Error Propagation in Computational Graphs
Even though mixed precision computations have intrinsic errors (just as any floating point calculations do), users can orchestrate these errors such that the end result stays in an acceptable range for the application. Indeed, this is the motivation behind TF32 mode in tensor cores when used in DL workloads. In TF32 mode, tensor cores downcast the FP32 inputs to the TF32 format, which incurs round-off error. However, the multiply-accumulate (MAC) is done in IEEE FP32 precision, which reduces the propagation of round-off errors that we would see if the MAC were done in a lower precision instead. In a similar vein, Micikevicius et al. multiply the loss by a constant factor so that the FP16 gradients fill as much of the FP16 range as possible, maximizing precision while preventing overflow. DL workloads can be effectively orchestrated to take advantage of mixed precision training in this way without loss of model performance, which is exactly what AMP does.
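A minimal sketch of that recipe with PyTorch AMP (autocast plus GradScaler); the model, optimizer, and data below are placeholders, and a CUDA device is assumed:

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()           # scales the loss so FP16 gradients stay in range

for _ in range(10):
    inputs = torch.randn(64, 1024, device="cuda")
    targets = torch.randn(64, 1024, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():            # ops run in FP16/FP32 as appropriate
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()              # backward on the scaled loss
    scaler.step(optimizer)                     # unscales gradients, skips the step on inf/nan
    scaler.update()                            # adjusts the scale factor for the next iteration
```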
For non-DL workloads that are sensitive to floating point errors, one should be aware of the finer details of how errors propagate. Sanchez-Stern et al. show that intermediate floating point errors can hide in a graph of computations and are frequently hard to pinpoint. They present three situations:
- Non-composition error: intermediate errors can accumulate and output an inaccurate result or they can cancel to provide an accurate result.
- Non-local error: the symptom of the error can show in one function but the root cause can be in another lower level function.
- Non-uniform error: inputs from different distributions can cause significantly different amounts of error.
In fact, non-local errors can arise from bugs in the math implementation of operations. Math libraries have many kernels that compute the same operation for different input sizes. Variations among these kernels produce variations in accuracy, so accuracy is not just a function of precision and input distribution, but also of how the result is computed. One common example is that tree reductions (frequently used on GPUs) are typically much more accurate than serial reductions (frequently used on processors with less available parallelism).
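A small illustration of that last point, using a toy pairwise reduction rather than any particular library's kernel (the exact errors depend on the inputs):

```python
import torch

torch.manual_seed(0)
x = torch.randn(2**15, dtype=torch.float32) + 1.0   # shift so the sum is large
ref = x.double().sum().item()                        # FP64 reference

# Serial left-to-right accumulation in FP32.
serial = torch.tensor(0.0, dtype=torch.float32)
for v in x:
    serial = serial + v

# Pairwise (tree-style) reduction in FP32: repeatedly add adjacent halves.
tree = x.clone()
while tree.numel() > 1:
    if tree.numel() % 2:                             # pad to an even length
        tree = torch.cat([tree, tree.new_zeros(1)])
    tree = tree[0::2] + tree[1::2]

print("serial error:", abs(serial.item() - ref))
print("tree   error:", abs(tree.item() - ref))       # typically much smaller
```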