tl;dr: This post explains what adding a new dtype to PyTorch core means, the criteria for adding one, and our official recommendation for supporting new “secondary dtype” use cases, such as (group-)quantized uint4, that are backed by native dtypes like uint4.
Context
We recently discussed adding int4 dtypes to PyTorch with our customers. The main sentiment from users is that int4 is becoming pretty popular, and multiple people/teams have requested it, both internally at Meta and in OSS. However, all existing use cases can already be served without a new dtype, so the exact benefit of adding one is not clear. We also have no criteria for which features a native dtype in PyTorch must support. This post aims to answer some common questions about how we think about adding new dtypes to PyTorch (the same reasoning applies to other core constructs, such as devices).
1. Current native dtypes in PyTorch
See c10/core/ScalarType.h in the pytorch/pytorch repository on GitHub, and the torch.Tensor page of the PyTorch 2.1 documentation.
Here are the existing dtypes for torch.Tensor:
- float32, float64, float16, bfloat16, float8_e5m2, float8_e4m3fn
- complex32, complex64, complex128 (historical reasons)
- uint1 to uint7, uint8, int8, int16, int32, int64, uint16, uint32, uint64
  - uint1 to uint7, uint16, uint32 and uint64 have limited operator support. These dtypes exist for interoperability and ease of integration with PT2, but we don’t plan to add full eager kernel coverage for them.
- bool
- quint8, qint8, qint32, quint4x2, quint2x4
  - We plan to deprecate these because we are moving to a more sustainable and scalable way of representing quantized tensors: storing the quantization parameters in the operators instead of in the Tensor itself, and using native PyTorch tensors.
- Uninterpreted bits types (bits8, bits4x2, bits2x4, bits1x8)
  - These were added earlier to hold data without attaching any semantics. We plan to deprecate them in favor of uint1 to uint8.
Some notes
- We have had requests for uint16 from quantization customers. Our previous recommendation was to use int32 with quant_min/quant_max to simulate it. Now that uint16 is in PyTorch core with some barebones support, that should be enough for the quantization use case.
- Not all dtypes have the same level of support across operators and features (e.g., complex dtypes), and it is currently not clear how well each dtype is supported.
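To make the old int32 workaround concrete, here is a framework-free sketch of affine (de)quantization. The quantization parameters are arguments of the operator rather than properties of the tensor, and a uint16 range is simulated on int32-style integers by clamping to quant_min=0 and quant_max=65535. The function names are hypothetical. Plain Python ints stand in for int32 tensor elements, and Python's round() (round-half-to-even) stands in for whatever rounding a real kernel uses.

```python
# Hypothetical helpers: quantization parameters are passed to the operator,
# not stored on the tensor. Python ints stand in for int32 tensor elements.
UINT16_MIN, UINT16_MAX = 0, 2**16 - 1


def quantize_affine(values, scale, zero_point, quant_min, quant_max):
    """Map floats to integers, clamping them to [quant_min, quant_max]."""
    out = []
    for v in values:
        q = round(v / scale) + zero_point  # round-half-to-even
        out.append(max(quant_min, min(quant_max, q)))
    return out


def dequantize_affine(qvalues, scale, zero_point):
    """Map integers back to floats using the same operator arguments."""
    return [(q - zero_point) * scale for q in qvalues]


# Simulate uint16 by clamping int32-style values to [0, 65535].
q = quantize_affine([-1.0, 0.0, 1.5, 1e9], scale=0.5, zero_point=0,
                    quant_min=UINT16_MIN, quant_max=UINT16_MAX)
print(q)                          # [0, 0, 3, 65535]
print(dequantize_affine(q, 0.5, 0))  # [0.0, 0.0, 1.5, 32767.5]
```

Because the range lives in quant_min/quant_max, the same int32 storage can emulate uint16, int16, or any narrower range without a dedicated dtype. A native uint16 removes the need for that wider storage.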
2. What are the features that need to be supported for a native PyTorch dtype?
- Basic Tensor support
  - torch.{new_dtype}
  - Type promotion
    - Not all dtypes have to support type promotion. For example, bits dtypes don’t support promotion because it is not a meaningful concept for them: they don’t represent any abstract notion of integers.
  - Device support list
    - CPU, CUDA, …
  - Tensor creation (factory functions: torch.empty, torch.zeros, torch.ones, etc.)
    - In particular, it must be meaningful to write something like torch.tensor([0, 1], dtype=my_dtype). If your “dtype” requires extra metadata, it’s not a dtype!
  - Tensor shape operators (view, select, slice, reshape, etc.)
  - Operator support list (conv2d, linear, etc.)
    - For extremely low-precision dtypes (e.g., 8 bits and below), we expect most “traditional” operations, like multiplication, to be supported neither in eager mode nor under compile. Naive operations without scales (e.g., linear(uint4_tensor, uint4_tensor)) are typically not useful. Instead, you need a more complicated fused operation that accounts for the scales (e.g., fp16_act_int4_weight(fp32_act, int4_weight, weight_scales)). We expect to support basic functionality (casting) and fused operations in eager mode.
  - Serialization
  - And many more…
- Feature/Composability (may overlap with the above)
  - torch.export with a Tensor of the given dtype
    - Note: during export, operations on some of these dtypes (e.g., torch.uint4) will by default be desugared into the underlying implementation (using torch.uint8). We may provide some special C++ ops to make pattern matching easier.
  - torch.compile (Inductor/Triton) with a Tensor of the given dtype
  - Other distributed features: RPC, DDP, etc.
  - Autocast support, if it is a floating-point type (and where it applies; e.g., it may not apply to float8)
  - nn / optimizer
  - Profiler
  - Visualization
  - And many more…
- Other costs
  - Binary size
  - Compilation time
  - Test time
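The export note above says torch.uint4 operations are desugared into the underlying torch.uint8 storage. This post does not specify the exact layout PyTorch uses. The sketch below assumes, purely for illustration, two 4-bit values per byte with the first element in the low nibble, and it uses Python's bytes in place of a uint8 tensor. The helper names are hypothetical.

```python
def pack_uint4(values):
    """Pack 4-bit unsigned ints (0..15) two per byte, low nibble first.

    The layout is an assumption for illustration; odd-length input is
    padded with a trailing 0.
    """
    if any(not 0 <= v <= 15 for v in values):
        raise ValueError("uint4 values must be in [0, 15]")
    padded = list(values) + [0] * (len(values) % 2)
    return bytes(lo | (hi << 4) for lo, hi in zip(padded[0::2], padded[1::2]))


def unpack_uint4(packed, n):
    """Recover the first n uint4 values from the packed bytes."""
    out = []
    for byte in packed:
        out.append(byte & 0x0F)         # low nibble
        out.append((byte >> 4) & 0x0F)  # high nibble
    return out[:n]


vals = [1, 15, 7]
packed = pack_uint4(vals)
print(packed.hex())                      # f107
print(unpack_uint4(packed, len(vals)))   # [1, 15, 7]
```

Every shape or elementwise operation on a sub-byte tensor has to respect this nibble bookkeeping. That is one reason the post expects only casting and fused operations, rather than the full operator set, for such dtypes.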
3. What are the criteria for adding a new dtype to PyTorch core?
- We can predict wide future usage of the dtype, for instance because it will be supported in silicon on major accelerator hardware that we support. If the next generation of NVIDIA GPUs is going to natively support a format, that is highly predictive of it being widely used.
- The dtype must be meaningful without any extra metadata. For example, fp8 has a well-defined interpretation without a scaling factor, so it can be a (floating-point) dtype. any4 is only defined together with a 2^4-entry lookup table mapping int values to float values, so it cannot be a dtype in the traditional sense. Another way to think about it: torch.tensor([0], dtype=your_dtype) must be meaningful. If it is not, you don’t have a dtype.
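The sketch below illustrates the metadata criterion. It decodes a float8_e4m3fn bit pattern using only the format's own definition: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, and NaN only for the all-ones exponent and mantissa. An any4 code, by contrast, can only be decoded if the caller supplies a 16-entry lookup table. The function names are hypothetical, and this is plain Python rather than PyTorch code.

```python
import math


def decode_e4m3fn(bits):
    """Decode one float8_e4m3fn byte; no scale or table is needed."""
    sign = -1.0 if bits & 0x80 else 1.0
    exponent = (bits >> 3) & 0x0F
    mantissa = bits & 0x07
    if exponent == 0x0F and mantissa == 0x07:
        return math.nan                              # the only NaN encoding
    if exponent == 0:
        return sign * 2.0 ** -6 * (mantissa / 8)     # subnormal
    return sign * 2.0 ** (exponent - 7) * (1 + mantissa / 8)


def decode_any4(code, lookup_table):
    """Decode one any4 code; meaningless without its 16-entry table."""
    if len(lookup_table) != 16:
        raise ValueError("any4 needs a 2**4-entry lookup table")
    return lookup_table[code & 0x0F]


print(decode_e4m3fn(0x38))  # 1.0
print(decode_e4m3fn(0x7E))  # 448.0 (largest finite value)
print(decode_any4(3, [i * 0.25 for i in range(16)]))  # 0.75 with this table
```

decode_e4m3fn has no extra parameter, which is why torch.tensor([0], dtype=torch.float8_e4m3fn) can be meaningful. decode_any4 cannot even be called without the table.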
4. What is enabled by having a native PyTorch dtype?
- Reduced friction for users
- Simplified packaging and dependencies
- Better native integration with hardware
5. Should we define an official path for “secondary dtypes” that don’t meet the criteria in section 3?
Use cases like group-quantized uint4, any4 and MX formats need extra metadata to make sense. We may see more of these as quantization becomes increasingly popular for LLMs, but they do not meet the criteria in section 3. The official extension point we recommend for these use cases is a Tensor subclass, or, alternatively, adding extension points to the existing systems. For example, to build a quantized uint4 Tensor based on the native torch.uint4 dtype, we could do the following:
class QuantizedUInt4Tensor(torch.Tensor):
...
As long as you rely only on operators/features that are supported for the underlying dtype, and you implement the Tensor subclass correctly, we expect it to work automatically with existing systems like Dynamo and torch.compile. We’ll publish a separate post about official support for new dtypes like uint4 after we validate them with some use cases. Until then, please wait for that documentation of per-dtype feature support (e.g., for uint4) before implementing “secondary dtypes” backed by these native dtypes.
A more complete example can be found here
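The following framework-free sketch shows what the extra metadata of a “secondary dtype” looks like for group-quantized uint4. The codes themselves are plain 4-bit integers, but interpreting them requires a per-group scale and zero point. The wrapper carries that metadata alongside the data, standing in for a QuantizedUInt4Tensor subclass. The class, its methods, the min/max asymmetric scheme and the group size are all illustrative assumptions, not the API of any PyTorch library. A real subclass would hold the codes in a torch.uint4 (or packed torch.uint8) tensor and hook into dispatch, for example via `__torch_dispatch__`.

```python
from dataclasses import dataclass


@dataclass
class GroupQuantizedUInt4:
    """Hypothetical stand-in for a QuantizedUInt4Tensor subclass.

    codes: uint4 values (0..15); scales/zero_points: one entry per group.
    """
    codes: list
    scales: list
    zero_points: list
    group_size: int

    @classmethod
    def quantize(cls, values, group_size):
        codes, scales, zero_points = [], [], []
        for start in range(0, len(values), group_size):
            group = values[start:start + group_size]
            lo, hi = min(group), max(group)
            scale = (hi - lo) / 15 or 1.0  # fall back to 1.0 for constant groups
            zero_point = round(-lo / scale)
            scales.append(scale)
            zero_points.append(zero_point)
            codes.extend(max(0, min(15, round(v / scale) + zero_point))
                         for v in group)
        return cls(codes, scales, zero_points, group_size)

    def dequantize(self):
        return [(q - self.zero_points[i // self.group_size])
                * self.scales[i // self.group_size]
                for i, q in enumerate(self.codes)]

    def dot(self, activations):
        """Scale-aware dot: applies each group's metadata inside the loop."""
        total = 0.0
        for i, (a, q) in enumerate(zip(activations, self.codes)):
            g = i // self.group_size
            total += a * (q - self.zero_points[g]) * self.scales[g]
        return total


w = GroupQuantizedUInt4.quantize([0.0, 2.5, 7.5, -2.0, 0.0, 5.5], group_size=3)
print(w.codes)           # [0, 5, 15, 0, 4, 15]
print(w.dot([1.0] * 6))  # 13.5
print(sum(w.codes))      # 39: ignoring the metadata gives a meaningless result
```

The last line mirrors the point made about linear(uint4_tensor, uint4_tensor) earlier in the post: the raw codes only mean something together with the scales, so the useful operation is the scale-aware one.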
Appendix: Some case studies
- qint8: despite being in ScalarType, this is NOT a good dtype, because it requires extra metadata to be interpreted. Things like torch.ones(N, dtype=torch.quint8) don’t work! We should not have added it.
- float8_e5m2 and other variants: these are OK to add because they have a meaningful interpretation without extra metadata, and they have silicon support in NVIDIA H100. Unlike classic dtypes, they have limited operator support.
- uint16, uint32, uint64: these are extremely well known (e.g., NumPy supports them), and we have added support for them in PyTorch. However, they have limited operator support for binary-size reasons (but we expect PT2 to be able to deal with the coverage gap).
- uint1, uint2, …, uint7: these are OK to add because sub-byte dtypes are reasonably popular, and native support for sub-byte sizes in PyTorch core makes arithmetic easier. Sub-byte dtypes are quite difficult to implement in C++, so we will only have Python support in the medium term. You can use these as the basis for other sub-byte formats.
- mx4: this is NOT a good dtype, because the microexponents need to be stored somewhere before the numbers can be interpreted. It is OK to have f8_e8m0 as a type to represent the exponents, though.
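The mx4 entry can be made concrete with a sketch of decoding one MX-style block. Each 4-bit element is an FP4 E2M1 value (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit). The whole block shares one E8M0 scale, which encodes 2**(bits - 127) and reserves 0xFF for NaN. This follows our reading of the OCP Microscaling (MX) specification. The function names are hypothetical, and packing the elements into bytes is omitted. The element codes alone cannot be interpreted without the separately stored shared exponent, while the E8M0 exponent is self-describing, which is why f8_e8m0 can be a type and mx4 cannot.

```python
import math


def decode_e8m0(bits):
    """E8M0 scale: a bare power of two, 2**(bits - 127); 0xFF is NaN."""
    if bits == 0xFF:
        return math.nan
    return 2.0 ** (bits - 127)


def decode_e2m1(code):
    """FP4 E2M1 element: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit."""
    sign = -1.0 if code & 0x8 else 1.0
    exponent = (code >> 1) & 0x3
    mantissa = code & 0x1
    if exponent == 0:
        return sign * 0.5 * mantissa  # subnormal: 0 or 0.5
    return sign * 2.0 ** (exponent - 1) * (1 + mantissa / 2)


def decode_mxfp4_block(shared_scale_bits, codes):
    """Every element in the block is multiplied by the same shared scale."""
    scale = decode_e8m0(shared_scale_bits)
    return [scale * decode_e2m1(c) for c in codes]


print([decode_e2m1(c) for c in range(8)])
# [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
print(decode_mxfp4_block(128, [0b0010, 0b1111]))  # [2.0, -12.0]
```

The same element code 0b0010 decodes to 2.0 here but to 1.0 with shared scale 127, so the value depends on metadata stored outside the element.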