PyTorch and TensorFloat32

Context

TensorFloat32 (TF32) is a math mode introduced with NVIDIA’s Ampere GPUs. When enabled, it computes float32 matrix multiplications (GEMMs) and convolutions faster but with reduced numerical accuracy. For many programs this results in a significant speedup with negligible accuracy impact, but for some programs the reduced accuracy has a noticeable, and sometimes serious, effect.
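As a rough illustration of the gap: TF32 keeps float32’s 8-bit exponent but carries only 10 explicit mantissa bits instead of 23, so the spacing between adjacent representable values near 1.0 grows from about 2^-23 to about 2^-10 (a quick sketch):

fp32_eps = 2.0 ** -23  # ~1.19e-7: float32 has 23 explicit mantissa bits
tf32_eps = 2.0 ** -10  # ~9.77e-4: TF32 keeps only 10 explicit mantissa bits
print(fp32_eps, tf32_eps)  # 1.1920928955078125e-07 0.0009765625

This is why, for example, a value like 1.0001 simply rounds to 1.0 on its way into a TF32 GEMM.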

Support for TensorFloat32 operations was added in PyTorch 1.7, and today PyTorch’s matrix multiplications and convolutions use TensorFloat32 on Ampere hardware by default. The behavior can be controlled with two global switches:

  • torch.backends.cuda.matmul.allow_tf32 and
  • torch.backends.cudnn.allow_tf32
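A minimal sketch of toggling both switches:

import torch

# Disable TF32 for float32 GEMMs but keep it for cuDNN convolutions
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = True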

Making TensorFloat32 the default math mode was based on its improved performance and on NVIDIA’s demonstration that a variety of deep learning networks could still be trained successfully. Unfortunately, this default means that some operations produce significantly less accurate results on Ampere hardware than on CPUs and other NVIDIA hardware, and some programs, including deep learning networks, no longer work as expected.

PyTorch made several updates to accommodate this reduction in accuracy:

  • Fixes new tf32 failures in test_nn.py (#52871)
  • TF32 threshold twiddling for tests (#60209)
  • Increase some tolerances for tf32 for Conv3d tests (#60451)

And documented the accuracy discrepancy in the CUDA semantics notes, too.

Issues

Some users have expressed confusion and frustration with the new default math mode on Ampere hardware. Here are a few highlights, but see the Appendix for more details:

  • At the operator level there are simple reproductions of highly confusing results, like multiplying by an identity matrix not being an identity operation, or [[1.0]] @ [[1.0001]] resulting in 1 while [1.0] @ [1.0001] correctly results in 1.0001 (see the sketch after this list). These examples raise concerns about PyTorch’s core operator correctness on Ampere GPUs.
  • Inconsistent results between Ampere GPUs, the CPU, NumPy, and other NVIDIA hardware has been a source of confusion for users who are validating their programs against references or upgrading to Ampere.
  • Deep learning programs are affected, too.
    • Some machine learning models operate in contexts where results must conform to specific mathematical properties. For example, regression models used to approximate and accelerate physics simulations must produce answers that obey the relevant conservation laws, and models in chemistry must produce results that are stable under geometric translation and rotation. These algorithms use constant tensors with very specific mathematical meanings, and they have been observed to give significantly degraded results when run with TF32 enabled (example: github issue #69921).
    • Operations performed during data loading may need to be carried out with extremely high fidelity.
    • Real deep learning networks in the community, including NVIDIA’s own reference implementation of StyleGAN2-ADA for PyTorch, disable TensorFloat32 GEMMs.
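The operator-level confusion is easy to reproduce. A minimal sketch (it assumes an Ampere or newer GPU; on other hardware both paths agree):

import torch

a = torch.tensor([[1.0001]], device="cuda")
eye = torch.eye(1, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = True   # the current default on Ampere
print(eye @ a)  # may print tensor([[1.]]): 1.0001 rounds to 1 under TF32

torch.backends.cuda.matmul.allow_tf32 = False
print(eye @ a)  # tensor([[1.0001]]): the full float32 GEMM is exact here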

While the TensorFloat32 math mode default appears to be working for (and benefiting) the majority of programs run on Ampere (although many PyTorch users do not yet use Ampere hardware), the volume of issues, the observed user confusion, and the difficulty of diagnosing when the TF32 math mode is the culprit all suggest that making TF32 the default math mode is not working as we had hoped.

Mitigations

If not mitigated, the issues cited above will only become more prevalent as more of the PyTorch community adopts newer NVIDIA hardware. Our goals with any mitigation are:

  • to continue making it easy for users to get the most out of their hardware
  • to continue supporting the state-of-the-art in mixed precision computation
  • to keep users in control by ensuring our operators deliver the expected results and that upgrading to newer hardware isn’t disruptive

PyTorch’s Automatic Mixed Precision (AMP) module seems like an effective guide for how to update our thinking around the TF32 math mode for GEMMs. While not on by default, AMP is a popular module that users can easily opt into. It provides a tremendous amount of clarity and control, and users credit it for the speedups it delivers. It is also straightforward to diagnose issues with it, because users understand when it is enabled and when it is not. The TF32 math mode for GEMMs seems like a comparable performance optimization, as the sketch below suggests.
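For reference, a minimal sketch of AMP’s explicit opt-in (a toy training step; the model, optimizer, and shapes are placeholders):

import torch

model = torch.nn.Linear(128, 128).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(32, 128, device="cuda")
with torch.cuda.amp.autocast():  # reduced precision applies only inside this block
    loss = model(x).square().mean()
scaler.scale(loss).backward()    # scaling avoids gradient underflow in float16
scaler.step(opt)
scaler.update()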

As such, we propose changing the default math mode on Ampere devices to be consistent with other CUDA devices and requiring users to opt in to the TensorFloat32 math mode. This is a disruptive change for existing Ampere users, but we can minimize that disruption and continue helping users adopt mixed precision operations by:

  • Updating our performance profiling tools to recommend trying AMP and the TF32 math mode for GEMMs
  • Updating our documentation and guidance around the TF32 math mode, including identifying which operators are likely to see improved performance when the TF32 math mode is enabled
  • Working with partners – like Huggingface and PyTorch Lightning – to pick the correct TF32 math mode defaults for their frameworks and to expose mixed precision options that provide clarity to users

Some alternatives to the above, like trying to set the math mode per operator, are discussed in the issue “RFC: Should matmuls use tf32 by default?” (#67384). That idea is particularly challenging to adopt because:

  • Technically, the TF32 math mode is implemented as a global switch that we cannot change on a per-op (or per-thread) basis without a performance penalty.
  • Our investigations into enabling TF32 contextually suggest it would be technically complicated, incur a performance penalty, and reduce UX clarity. We think it’s better to educate users about how to use mixed precision in their programs.

Note that we are not proposing changing the default math mode for convolutions. Very few – if any – of the issues we’ve reviewed suggest a problem with TF32 as the default math mode for convolutions.

Appendix – TF32 User Issues

Example issues:

  • torch.matmul() gives rounded results for 30xx cards (#53635)
    • “I will close the issue then, though I would say that it is quite a questionable feature to be enabled by default. In particular, because now it leads to different results of the same code running with Ampere or not Ampere GPU.”
  • simple matrix multiplication yields wrong result on Ampere (3080) (#55355)
    • “Loading to .cpu() gets the correct result. On two other machines with 2080 and Titan X there’s no such problem.”
  • The computation results on RTX 3090 are totally different from others (#58434)
    • “I compared the computation results from [an RTX 3080, a GTX 1080, CPU, and NumPy] and results from the last 3 were the same, whereas that from RTX 3090 were totally different.”
  • Incorrect result with CUDA (#58688)
    • “Matrix multiplication with the unit matrix does not reproduce input precisely on CUDA. On CPU result is correct.”
  • The matrix multiplication operator can’t get correct results on 3090 !! (#61890) (+2)
    • “because [I’m multiplying by an identity matrix], the output should be unchanged.”
    • In this case the user is doing a geometric transformation (rotation)
  • Precision issue in matrix-vector multiplication (#66212)
  • Numerical inconsistencies on GPU when computing A.T@B vs (B.T@A).T (#67185)
    • “When running on a CPU, the difference between a.T@b vs (b.T@a).T is small enough to be accounted for by precision error. However, on a gpu, the error is huge. The error on a gpu disappears when I slice off only small fraction of the b matrix. Also, casting everything to double fixes the issue.”
    • “I’d like to argue though whether turning this option on by default on Amperes without any visible indication is a good thing. I was always under the impression that tf32/fp16/… lower-precision solutions were explicit so that bugs like mine don’t happen. I also see the benefits of turning it on by default (it will be much faster without considerable downsides for a typical user).”

PyTorch forum posts:

  • Pytorch matmul is inconsistent on GPU (thread)
  • Considerable absolute error in torch.matmul (thread)
  • CPU/GPU results inconsistent with matrix multiplication (thread)
  • Got Different Inference Conv2d Results on Different GPU Machine (thread)
  • Bug? matmul seems to cast to float16 internally (thread)
  • L-BFGS optimizer doesn’t work properly on CUDA (thread)

Examples from other sites:

  • Using Pytorch model trained on RTX2080 on RTX3060 (Stack Overflow)
    • “The flag torch.backends.cuda.matmul.allow_tf32 = false needs to be set, to provide a stable execution of the model of a different architecture.”
  • improve test F1 score from 88 to 96 via changing GPUs? (Twitter)

Examples from deep learning code:

Sourcegraph’s code search also highlights some concerning comments, like in this code snippet from “Deep Diffusion Models for Robust Channel Estimation” from UT Austin…

# !!! Always !!! Otherwise major headache on RTX 3090 cards
torch.backends.cuda.matmul.allow_tf32 = False

… as well as neural network implementations disabling TensorFloat32 explicitly, including NVIDIA’s own reference implementation of StyleGAN2-ADA for PyTorch.

There have been additional questions from within Meta (“this sort of magic is exactly what drives people crazy with TF/Keras”), including from teams doing performance analysis who struggled to explain supposedly fp32 operations running above peak fp32 throughput. (On an A100, for example, TF32 tensor cores peak at 156 TFLOPS versus 19.5 TFLOPS for standard fp32, so a “float32” GEMM can appear to exceed the fp32 ceiling several times over.)

Second this! We often do linalg solves / compute matrix decompositions with ill-conditioned matrices, and TF32 pretty much always breaks this, and we always have to remember to turn off tf32 computations depending on which hardware this is run on (which makes it very cumbersome to write hardware-agnostic code).

Update: in consultation with our colleagues at NVIDIA, we will be changing the default value of torch.backends.cuda.matmul.allow_tf32 to False. This is a disruptive change, and we will minimize the disruption by updating our documentation and profiling tools to recommend that users try enabling torch.backends.cuda.matmul.allow_tf32 to improve performance when appropriate.

Current TensorFloat32 matmul users will have to set torch.backends.cuda.matmul.allow_tf32 = True explicitly to preserve their performance improvements.
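In practice the opt-in is a one-line change at the top of a program (sketch):

import torch

# Restore the previous Ampere behavior for float32 GEMMs
torch.backends.cuda.matmul.allow_tf32 = True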

This change was principally motivated by wanting to provide users with clarity about the math mode(s) they’re using. More details about the timing of this change and how we’re mitigating its impact will be available soon.

fyi: “Add high level control of fp32 matmul precision; disable TF32 for matmuls by default” by eqy (pytorch/pytorch#76509) has landed and changes the default to the highest-precision fp32 matmul math mode. It also introduces a new device-agnostic way of controlling this mode.
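For reference, the new control looks like this (a brief sketch; the precision names follow the PR):

import torch

# Device-agnostic control over fp32 matmul precision
torch.set_float32_matmul_precision("highest")  # full float32 matmuls (the new default)
torch.set_float32_matmul_precision("high")     # allow TF32 (or comparable) matmuls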

I just installed 1.12 (pip; Windows; CUDA 11.6) and I don’t see torch.backends.matmul.allow_tf32:

>>> import torch
>>> torch.__version__
'1.12.0+cu116'
>>> torch.backends.matmul.allow_tf32
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: module 'torch.backends' has no attribute 'matmul'
>>> torch.backends.cudnn.allow_tf32
True
>>> dir(torch.backends)
['ContextProp', 'PropModule', '__allow_nonbracketed_mutation', '__allow_nonbracketed_mutation_flag', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'contextmanager', 'cuda', 'cudnn', 'disable_global_flags', 'flags_frozen', 'mkl', 'mkldnn', 'mps', 'openmp', 'quantized', 'types']

https://pytorch.org/docs/master/backends.html#torch.backends.cuda.torch.backends.cuda.matmul.allow_tf32

I think it’s torch.backends.cuda.matmul.
