import torch

def f(x, scale):
    max_scaled = x * scale
    # Mathematically this is exp(x * scale - x * scale) == exp(0) == 1.
    return torch.exp(max_scaled - x * scale)
In this case, f(x, scale) = 1, but torch.compile(f)(x, scale) = inf. Why is that?
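For context, a hedged sketch of how you would observe this (the exact inputs that trigger the divergence are in the issue linked below; the shapes and values here are just illustrative, and it takes a backend where the compiled kernel rounds x * scale differently in the two places, e.g. via an FMA rewrite):

import torch

# Illustrative inputs only; assumes a CUDA device. See the linked issue for a real repro.
x = torch.randn(4096, device="cuda") * 1e4
scale = torch.tensor(1e10, device="cuda")

print(f(x, scale))                 # eager: exp(0) == 1 everywhere
print(torch.compile(f)(x, scale))  # compiled: elements can blow up to inf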
Go read this comment (https://github.com/pytorch/pytorch/issues/121558…) to learn more! But basically, the general thing is that we're computing something like exp(f(x) - f(x)). As long as f(x) = f(x), things are good! However, if the computation of f(x) diverges for any reason, then this can lead to arbitrarily bad results.
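To see why "arbitrarily bad" is not an exaggeration: the two computations of f(x) only have to disagree by a tiny relative amount, because if f(x) itself is large, that tiny relative error is a huge absolute error, and exp() of it saturates. A rough fp32 illustration (the explicit perturbation below just stands in for the two computations rounding differently):

import torch

# Two "computations of f(x)" that agree to about one fp32 ulp of relative
# error, on a value of magnitude ~3e9. In absolute terms they differ by a few hundred.
a = torch.tensor(1.0e10, dtype=torch.float32) / 3
b = a * (1 + 1e-7)   # stand-in for the same value rounded slightly differently

print(a - b)             # roughly -4e2: tiny relative error, large absolute error
print(torch.exp(b - a))  # inf (fp32 exp overflows above ~88.7)
print(torch.exp(a - b))  # 0.0 (and underflows just as easily)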
This is unfortunately quite challenging since, generally speaking, we have not guaranteed that our numerics are identical to eager's. In particular, it's possible for us to do less rounding than eager, which is locally more accurate but diverges from what eager computed.
However, even with FMA off, this can lead to unboundedly worse numerics for user code (see https://github.com/pytorch/pytorch/issues/122260…).
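For a concrete flavor of "less rounding than eager" (a common pattern with fused low-precision kernels, sketched here as an assumption; the linked issue may involve a different op pattern): eager materializes each bf16 intermediate, while a fused kernel can keep the whole chain in fp32 and round once at the end. Neither is wrong in isolation, but they drift apart:

import torch

x = torch.randn(4096, dtype=torch.bfloat16)
y = torch.randn(4096, dtype=torch.bfloat16)

# Eager-style: every intermediate is rounded back to bf16.
eager_like = (x * y + y) * x

# Fused-kernel-style: compute the chain in fp32, round once at the end.
fused_like = ((x.float() * y.float() + y.float()) * x.float()).to(torch.bfloat16)

# The second version did less rounding, yet it no longer matches eager.
print((eager_like.float() - fused_like.float()).abs().max())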
Here is also a repro with higher numeric error for user code. This is an example of how things can go badly in combination with Inductor heuristics like rematerialization.
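The actual repro is the pastebin linked just below; purely as a reading aid for the buffer names in the breakdown, here is a hypothetical sketch of the general shape of that user code (a hand-written softmax over a division; the real code may differ):

import torch

def softmax_of_div(x, y):
    # In the generated kernels, the inputs show up as arg0_1 / arg1_1 and the
    # two reductions as buf0 (the amax) and buf1 (the sum).
    t = x / y                                  # div
    m = t.amax(dim=-1, keepdim=True)           # amax -> buf0
    e = (t - m).exp()                          # sub, exp
    return e / e.sum(dim=-1, keepdim=True)     # sum -> buf1, then div_1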
There are 3 issues that go on here (this is the code: https://pastebin.com/80pbhpLt):
- We choose to unroll sum/amax, preventing fusion with the downstream div_1.
- We choose to rematerialize most of the pointwise ops (div, sub, and exp), i.e. we compute them in both the first kernel and the second kernel. This is also suboptimal, since we load both inputs in two kernels for no reason.
- In the second kernel, we are computing ((arg0_1 / arg1_1) - buf0).exp() / buf1, where buf0 is supposed to be (arg0_1 / arg1_1).amax(). So, for some element this should be ((arg0_1 / arg1_1) - (arg0_1 / arg1_1)).exp(), which is equal to 1. However, because of FMA, the subtraction is rewritten as fma(arg0_1, (1/arg1_1), -buf0), which is not exactly 0, so the exp does not equal 1 (see the sketch after this list).
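Here is a minimal sketch of that last point (the one referenced above), using math.fma from Python 3.13+ as a stand-in for the fused multiply-add the generated kernel uses; the values are illustrative:

import math

a = 1.0 / 3.0   # plays the role of arg0_1
r = 7.0         # plays the role of 1 / arg1_1
buf0 = a * r    # the rounded product, i.e. what the amax produced

# Written as (a * r) - buf0, the subtraction sees the rounded product, so the
# result is exactly 0 and exp(0) == 1.
print(math.exp(a * r - buf0))   # 1.0

# Rewritten as fma(a, r, -buf0), the product is not rounded before the
# subtraction, so we get the rounding error of a * r instead of 0.
residual = math.fma(a, r, -buf0)
print(residual)                 # tiny but typically nonzero
print(math.exp(residual))       # not exactly 1.0

# In the fp32 softmax case the operands can be much larger, so this residual
# can be large in magnitude, at which point exp() overflows to inf (or
# underflows to 0).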
TL;DR: FMA bad. I think we should turn it off for most kernels. Today, we cannot guarantee that Inductor-generated code always has numerics at least as good as the user's eager code, particularly in the presence of softmax. Turning off FMA will prevent the worst from occurring with our version of softmax (which is run in fp32).