FP8 datatype in PyTorch

Is there any plan to support FP8 as a datatype in PyTorch?


There’s no immediate plan to support this in the quantization workflow. I also don’t think there’s any immediate plan to support it as an unquantized type.


Would you accept an external contribution for it?

If so, is there any guideline about what that contribution should contain for this new data type to be accepted?

Note that there are two data representations for FP8 (E4M3 and E5M2). According to a recent study ([2209.05433] FP8 Formats for Deep Learning), both representations might be needed to get good accuracy. When you mention FP8, do you mean both representations or only one of them?
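To make the difference between the two representations concrete, here is a minimal Python sketch that decodes a raw byte under each format. It assumes the encodings described in arXiv 2209.05433: E4M3 has 4 exponent bits with bias 7, no infinities, and only S.1111.111 reserved for NaN; E5M2 has 5 exponent bits with bias 15 and follows IEEE-754 conventions for infinities and NaNs. `decode_fp8` is a hypothetical helper for illustration, not a PyTorch API.

```python
import math

def decode_fp8(byte: int, fmt: str) -> float:
    """Decode an 8-bit pattern as E4M3 or E5M2 (layouts per arXiv 2209.05433)."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("byte must be in [0, 255]")
    if fmt == "e4m3":
        exp_bits, man_bits, bias = 4, 3, 7
    elif fmt == "e5m2":
        exp_bits, man_bits, bias = 5, 2, 15
    else:
        raise ValueError(f"unknown format {fmt!r}")

    sign = -1.0 if byte >> 7 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    exp_max = (1 << exp_bits) - 1

    if fmt == "e5m2" and exp == exp_max:
        # IEEE-754-style specials: mantissa 0 is infinity, otherwise NaN.
        return sign * math.inf if man == 0 else math.nan
    if fmt == "e4m3" and exp == exp_max and man == (1 << man_bits) - 1:
        # E4M3 has no infinities; only S.1111.111 encodes NaN.
        return math.nan
    if exp == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias.
        return sign * (man / (1 << man_bits)) * 2.0 ** (1 - bias)
    # Normal: implicit leading 1.
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

print(decode_fp8(0x7E, "e4m3"))  # largest finite E4M3: 448.0
print(decode_fp8(0x7B, "e5m2"))  # largest finite E5M2: 57344.0
```

The two maxima illustrate the trade-off: E4M3 keeps an extra mantissa bit for precision but tops out at 448, while E5M2 gives up that bit to reach 57344.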

Both FP8 formats. NVIDIA recently added FP8 support to Transformer Engine (GitHub - NVIDIA/TransformerEngine), a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper GPUs to provide better performance with lower memory utilization in both training and inference. Is there any plan to integrate this into PyTorch so that FP8 is available for other backends as well?

We discussed FP8 at a recent composability meeting; the public minutes are in the Composability meeting notes (Google Docs).

In summary: while it is a bit premature to add proper FP8 types to PyTorch, we are going to add generic bits8/16/etc. types so you can easily prototype FP8 in a tensor subclass, without needing core to add all of the support you would otherwise need. Angela Yi is looking into adding this support!
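As a rough illustration of that prototyping approach, the sketch below stores FP8 values as plain uninterpreted bytes (standing in for the proposed bits8 type) and layers E4M3 semantics on top in a wrapper class, roughly the role a tensor subclass would play. It is pure Python rather than a `torch.Tensor` subclass, and the class `Fp8E4M3Buffer` and its methods are hypothetical. Encoding is a brute-force nearest-value search that saturates at ±448 and breaks ties toward the lowest code rather than using round-to-nearest-even.

```python
import math

def _e4m3_to_float(code: int) -> float:
    # E4M3: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
    sign = -1.0 if code & 0x80 else 1.0
    exp, man = (code >> 3) & 0xF, code & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan  # S.1111.111 is the only NaN pattern; no infinities
    if exp == 0:
        return sign * (man / 8) * 2.0 ** -6  # subnormal
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)

# Every non-NaN code, in ascending order, for brute-force encoding.
_FINITE_CODES = [c for c in range(256) if c not in (0x7F, 0xFF)]

class Fp8E4M3Buffer:
    """Hypothetical wrapper: raw 8-bit storage with E4M3 meaning on top."""

    def __init__(self, values):
        # The storage itself is just uninterpreted bytes.
        self.raw = bytearray(self.encode(v) for v in values)

    @staticmethod
    def encode(x: float) -> int:
        if math.isnan(x):
            return 0x7F
        # Saturate out-of-range inputs (including infinities) to +/-448.
        x = max(-448.0, min(448.0, x))
        # Nearest finite value; min() keeps the first (lowest) code on ties.
        return min(_FINITE_CODES, key=lambda c: abs(_e4m3_to_float(c) - x))

    def to_floats(self):
        # Interpret each stored byte as an E4M3 value.
        return [_e4m3_to_float(b) for b in self.raw]

buf = Fp8E4M3Buffer([1.0, 0.3, 1000.0, -2.5])
print(list(buf.raw))     # [56, 42, 126, 194]
print(buf.to_floats())   # [1.0, 0.3125, 448.0, -2.5]
```

The design point is that nothing in the storage knows about FP8: the bytes could equally be reinterpreted as E5M2 by a different wrapper, which is what makes a generic 8-bit type useful for prototyping both formats.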


Could we get a status report on the generic bits8/16/etc. “uninterpreted” dtype support?