How to approach targets that don't support i64/f64?

Hi folks, in Torch-MLIR a common request is better support for targets where i64 or f64 is undesirable (either unsupported outright, or expensive to emulate when the user “knows” that 64-bit isn’t needed).

Unfortunately, a lot of things in PyTorch are i64 or f64, and a user can write a valid program that legitimately needs 64-bit support. Has there been any thought about how users can communicate their i64/f64 needs, in particular to compilers? Consistency with eager mode is of course important too.

I think there are basically 4 cases:

  1. Scalar (Python) floats. These are specced as 64-bit floats.
    • Thoughts: I don’t have specific insight here, but there are likely many scenarios where 64-bit is overkill (e.g. when deploying to a microcontroller). Some scientific codes might care about some factor being very precise though.
  2. Scalar (Python) ints. In Python these are arbitrary precision. In TorchScript they are truncated to 64-bit and in general PyTorch C++ code is often in terms of int64_t.
    • Thoughts: Often these are used to calculate view sizes and such. With large language models in the hundreds of GB these days, we cannot blindly use 32-bit indexing (though perhaps individual tensor dimensions remain within 32-bit range?).
  3. Tensors with 64-bit floating point numbers.
    • Thoughts: PyTorch defaults to f32, so if a user asks for f64 they probably actually want the extra precision (?).
  4. Tensors with 64-bit integers. This is probably most common for embedding indices.
    • Thoughts: Most embeddings are likely OK to index with 32-bit indices, but they seem to be getting larger and larger, and it is not out of the question to need 64-bit indices there (anybody have a specific datapoint?).
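
The numbers behind cases 1 and 2 can be checked with plain Python. This is a standard-library sketch; the 100 GB tensor, the 2-byte (fp16) element size, and the `[50_000, 1_000_000]` shape are illustrative assumptions, not data from a specific model.

```python
import sys

INT32_MAX = 2**31 - 1
INT64_MAX = 2**63 - 1

# Case 1: Python floats are C doubles (53-bit significand), not f32 (24-bit).
float_mantissa_bits = sys.float_info.mant_dig

# Case 2: Python ints are arbitrary precision, so a valid program can exceed
# int64, which is where TorchScript / int64_t-based C++ code would truncate.
big = INT64_MAX + 1
exceeds_int64 = big > INT64_MAX

# Why 32-bit indexing is not always safe: a hypothetical 100 GB tensor of
# 2-byte (fp16) elements has more elements than an int32 can address.
num_elements = 100 * 10**9 // 2
needs_64bit_flat_index = num_elements > INT32_MAX

# ...but each individual dimension may still fit in 32 bits, e.g. a
# hypothetical [50_000, 1_000_000] shape holding the same elements.
dims = (50_000, 1_000_000)
dims_fit_int32 = all(d <= INT32_MAX for d in dims)

print(float_mantissa_bits, exceeds_int64, num_elements,
      needs_64bit_flat_index, dims_fit_int32)
```

So flat (linearized) offsets can need 64 bits even when every per-dimension size and index would fit in 32 bits.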

For reference, the Torch-MLIR issue is here: llvm/torch-mlir issue #1615, “Find a better solution for backends that don't support i64/f64”.

If I had infinite engineering time, this is probably how I would set things up:

  1. Known constant ints/floats can transparently decay to 32-bit if they are exactly representable at that precision. You could also provide a knob to (unsoundly) round doubles to floats.
  2. For index tensors, I would have a compiler analysis that identifies indexing operations. When an operator takes an index tensor and the tensor being indexed into is small enough for 32-bit indexing, the analysis can convert the index tensor to 32-bit and propagate this narrowing as far back as possible.
  3. I agree that explicit use of f64 should be respected, i.e. it should produce an error on targets without f64 support. You can of course unsoundly pretend all f64 are f32, but this probably won’t work well, since the user had to go out of their way to ask for f64.
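
A minimal sketch of the three rules above, in plain Python with only the standard library. All function names (`f32_exact`, `decay_constant`, `index_dtype`, `check_tensor_dtype`), the `allow_unsound_f32` / `unsound_f64_as_f32` knobs, and the `"i32"`-style type strings are hypothetical, not Torch-MLIR APIs. The index check is deliberately conservative (the total element count of the indexed tensor must fit in int32), and the backward propagation of narrowed indices is not shown.

```python
import math
import struct

INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1
INT64_MIN, INT64_MAX = -(2**63), 2**63 - 1


def f32_exact(x: float) -> bool:
    """True if the double x survives a round trip through f32 unchanged."""
    if math.isnan(x):
        return True  # NaN stays NaN (payload bits aside)
    try:
        return struct.unpack("<f", struct.pack("<f", x))[0] == x
    except OverflowError:
        return False  # finite, but larger than the f32 range


def decay_constant(value, allow_unsound_f32=False):
    """Rule 1 (hypothetical): give a known scalar constant the narrowest exact type."""
    if isinstance(value, int):
        if INT32_MIN <= value <= INT32_MAX:
            return ("i32", value)
        if INT64_MIN <= value <= INT64_MAX:
            return ("i64", value)
        raise ValueError("int constant does not fit in 64 bits")
    if f32_exact(value):
        return ("f32", value)
    if allow_unsound_f32:  # opt-in knob: round to the nearest f32
        try:
            return ("f32", struct.unpack("<f", struct.pack("<f", value))[0])
        except OverflowError:
            pass  # out of f32 range; cannot round, keep f64
    return ("f64", value)


def index_dtype(indexed_shape):
    """Rule 2 (hypothetical): narrow an index tensor to i32 when the tensor it
    indexes into is small enough. Conservative: every flat offset must fit."""
    return "i32" if math.prod(indexed_shape) <= INT32_MAX else "i64"


def check_tensor_dtype(dtype, target_supports_f64, unsound_f64_as_f32=False):
    """Rule 3 (hypothetical): respect explicit f64 and error out, unless the
    user opts into the unsound demotion to f32."""
    if dtype != "f64" or target_supports_f64:
        return dtype
    if unsound_f64_as_f32:
        return "f32"
    raise TypeError("tensor requests f64, but the target has no f64 support")
```

For example, `decay_constant(0.5)` gives `("f32", 0.5)`, while `decay_constant(0.1)` stays `("f64", 0.1)` because 0.1 has no exact f32 representation. A tensor of shape `(50_000, 1_000_000)` gets `"i64"` indices under this conservative check.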