FMAs (and softmax (and floating point)) considered harmful

import torch

def f(x, scale):
    max_scaled = x * scale
    # Mathematically this is exp(0) == 1 for every element.
    return torch.exp(max_scaled - x * scale)

In this case, f(x, scale) = 1, but torch.compile(f)(x, scale) = inf. Why is that?

Go read this comment (https://github.com/pytorch/pytorch/issues/121558…) to learn more! But basically, the general thing is that we’re computing something like exp(f(x) - f(x)). As long as f(x) = f(x), things are good! However, if the computation of f(x) diverges for any reason, then this can lead to arbitrarily bad results.
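
Here is a minimal sketch of the mechanism in plain Python (float64, using math.fma from Python 3.13+; purely illustrative, not what Inductor actually emits). A fused multiply-add rounds only once, so the recomputed x * scale no longer matches max_scaled, and the difference is the exact rounding error of the product, which for huge products is itself a huge number:

import math

x, scale = 3.0e300, 1.0 / 3.0
max_scaled = x * scale                 # rounded once to the nearest double

eager = max_scaled - x * scale         # same rounded product both times -> exactly 0.0
err = math.fma(x, scale, -max_scaled)  # one rounding: exposes the exact rounding error
                                       # of x * scale (up to half an ulp of the product,
                                       # which is ~1e284 at this magnitude)

print(eager)  # 0.0
print(err)    # typically a huge nonzero number

A contracted version of the subtraction in f computes exp(-err) rather than exp(0); depending on the sign of the rounding error, that overflows to inf or underflows to 0, either way nothing like 1.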

This is unfortunately quite challenging to fix, since generally speaking we have never guaranteed that our numerics are bitwise identical to eager's. In particular, it's possible for us to do less rounding than eager does, which is exactly what an FMA is.

However, even with FMA off, this kind of divergent recomputation can lead to unboundedly worse numerics for user code (see https://github.com/pytorch/pytorch/issues/122260…).

Here is another repro with higher numeric error in user code. It shows how things can go badly in combination with Inductor heuristics like rematerialization.

There are 3 issues at play here (the code is at https://pastebin.com/80pbhpLt):

  1. We choose to unroll sum/amax, preventing fusion with the downstream div_1.

  2. We choose to rematerialize most of the pointwise ops (div, sub, and exp), i.e. we compute them in both the first kernel and the second kernel. This is also suboptimal, since we load both inputs in two kernels for no reason.

  3. In the second kernel, we are computing ((arg0_1 / arg1_1) - buf0).exp() / buf1. Here, buf0 is supposed to be (arg0_1 / arg1_1).amax(), so there is some element for which this should be ((arg0_1 / arg1_1) - (arg0_1 / arg1_1)).exp(), which is equal to 1. However, because of FMA, the subtraction is rewritten as fma(arg0_1, (1 / arg1_1), -buf0), which is not exactly 0, so the exp is not exactly 1 (see the sketch below).
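
Here is a scalar sketch of issue 3 in plain Python (math.fma from Python 3.13+; a, b, and m are hypothetical stand-ins for one element of arg0_1, arg1_1, and buf0, not the actual buffers):

import math

a, b = 7.1, 3.0
m = a / b                              # buf0 stand-in: at the max element it equals a / b

eager_diff = (a / b) - m               # exactly 0.0, so its exp() is exactly 1.0
fused_diff = math.fma(a, 1.0 / b, -m)  # mimics the rewrite: a / b becomes a * (1 / b),
                                       # contracted with the subtraction into one fma;
                                       # generally a small nonzero value, so exp() != 1.0

print(eager_diff, math.exp(eager_diff))
print(fused_diff, math.exp(fused_diff))

The deviation here is typically only an ulp or so, but the same mechanism scales with the magnitude of the operands, which is how the first example ends up at inf.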

TL;DR: FMA bad. I think we should turn it off for most kernels. Today, we cannot guarantee that Inductor-generated code always has numerics at least as good as the user's eager code, particularly in the presence of softmax. Turning FMA off will prevent the worst failures with our version of softmax (which is run in fp32).
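
For reference, here is a hedged sketch of the usual max-subtraction softmax (written in eager PyTorch; not necessarily Inductor's exact decomposition). Its safety relies on the invariant that FMA contraction breaks:

import torch

def safe_softmax_sketch(x, dim=-1):
    # Hypothetical sketch of the max-subtraction softmax trick.
    # Invariant: x - m is computed from the same rounded values that produced m,
    # so every entry is <= 0 and exp() never exceeds 1.
    m = x.amax(dim=dim, keepdim=True)
    e = (x - m).exp()
    return e / e.sum(dim=dim, keepdim=True)

If a rewrite recomputes either side of that subtraction differently (FMA contraction, rematerialization with different rounding, etc.), x - m can come out positive, and when the operands are large that positive error can be large too, so exp() overflows to inf in fp32, which is the failure mode shown at the top of the post.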
