Inconsistent precision between PyTorch's built-in backends for the same device

I see an inconsistent approach to precision in matmul, and I’d like to understand why.

Consider this script:

import torch

DEVICE = "cpu"
device = torch.device(DEVICE)

def matmul_outfp32(x, y, bias):
    out = torch.matmul(x, y, out=torch.empty(32, 7, 1, 1, dtype=torch.float32, device=device))
    out = out + bias
    return out

x = torch.randn(32, 7, 1, 8).to(torch.bfloat16).to(device)
y = torch.randn(32, 7, 8, 1).to(torch.bfloat16).to(device)
bias = torch.randn(32, 7, 1, 1).to(torch.bfloat16).to(device)

func = torch.compile(matmul_outfp32, backend="aot_eager")

out = func(x, y, bias)
print(out.dtype)
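
(For reference, the graph dumps below come from enabling AOT graph logging. Something along these lines should work, though the exact logging API may differ between PyTorch versions; setting TORCH_LOGS="aot_graphs" in the environment is equivalent.)

import torch._logging

# Print the forward graphs produced by AOTAutograd during torch.compile.
torch._logging.set_logs(aot_graphs=True)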

When you run it with the aot_eager backend (or with any custom backend that goes through aot_autograd), you get a graph like this:

===== Forward graph 0 =====
/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
   def forward(self, arg0_1: "bf16[32, 7, 8, 1][56, 8, 1, 1]cpu", arg1_1: "bf16[32, 7, 1, 8][56, 8, 8, 1]cpu", arg2_1: "bf16[32, 7, 1, 1][7, 1, 1, 1]cpu"):
        # File: t.py:7 in matmul_outfp32, code: out = torch.matmul(x, y, out=torch.empty(32, 7, 1, 1, dtype=torch.float32, device=device))
       empty: "f32[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.empty.memory_format([32, 7, 1, 1], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
       expand: "bf16[32, 7, 1, 8][56, 8, 8, 1]cpu" = torch.ops.aten.expand.default(arg1_1, [32, 7, 1, 8]);  arg1_1 = None
       view: "bf16[224, 1, 8][8, 8, 1]cpu" = torch.ops.aten.view.default(expand, [224, 1, 8]);  expand = None
       expand_1: "bf16[32, 7, 8, 1][56, 8, 1, 1]cpu" = torch.ops.aten.expand.default(arg0_1, [32, 7, 8, 1]);  arg0_1 = None
       view_1: "bf16[224, 8, 1][8, 1, 1]cpu" = torch.ops.aten.view.default(expand_1, [224, 8, 1]);  expand_1 = None
       bmm: "bf16[224, 1, 1][1, 1, 1]cpu" = torch.ops.aten.bmm.default(view, view_1);  view = view_1 = None
       view_2: "bf16[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.view.default(bmm, [32, 7, 1, 1]);  bmm = None
       copy: "f32[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.copy.default(empty, view_2);  empty = view_2 = None
       
        # File: t.py:8 in matmul_outfp32, code: out = out + bias
       add: "f32[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.add.Tensor(copy, arg2_1);  copy = arg2_1 = None
       return (add,)

If you change the backend to inductor, the inputs are instead promoted to f32 before the matmul (which also gets decomposed into a mul + sum here):

===== Forward graph 0 =====
/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
   def forward(self, arg0_1: "bf16[32, 7, 8, 1][56, 8, 1, 1]cpu", arg1_1: "bf16[32, 7, 1, 8][56, 8, 8, 1]cpu", arg2_1: "bf16[32, 7, 1, 1][7, 1, 1, 1]cpu"):
        # File: t.py:7 in matmul_outfp32, code: out = torch.matmul(x, y, out=torch.empty(32, 7, 1, 1, dtype=torch.float32, device=device))
       empty: "f32[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.empty.memory_format([32, 7, 1, 1], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
       expand: "bf16[32, 7, 1, 8][56, 8, 8, 1]cpu" = torch.ops.aten.expand.default(arg1_1, [32, 7, 1, 8]);  arg1_1 = None
       view: "bf16[224, 1, 8][8, 8, 1]cpu" = torch.ops.aten.view.default(expand, [224, 1, 8]);  expand = None
       expand_1: "bf16[32, 7, 8, 1][56, 8, 1, 1]cpu" = torch.ops.aten.expand.default(arg0_1, [32, 7, 8, 1]);  arg0_1 = None
       view_1: "bf16[224, 8, 1][8, 1, 1]cpu" = torch.ops.aten.view.default(expand_1, [224, 8, 1]);  expand_1 = None
       convert_element_type: "f32[224, 1, 8][8, 8, 1]cpu" = torch.ops.prims.convert_element_type.default(view, torch.float32);  view = None
       convert_element_type_1: "f32[224, 8, 1][8, 1, 1]cpu" = torch.ops.prims.convert_element_type.default(view_1, torch.float32);  view_1 = None
       squeeze: "f32[224, 8][8, 1]cpu" = torch.ops.aten.squeeze.dim(convert_element_type, 1);  convert_element_type = None
       squeeze_1: "f32[224, 8][8, 1]cpu" = torch.ops.aten.squeeze.dim(convert_element_type_1, -1);  convert_element_type_1 = None
       mul: "f32[224, 8][8, 1]cpu" = torch.ops.aten.mul.Tensor(squeeze, squeeze_1);  squeeze = squeeze_1 = None
       sum_1: "f32[224, 1][1, 1]cpu" = torch.ops.aten.sum.dim_IntList(mul, [1], True);  mul = None
       unsqueeze: "f32[224, 1, 1][1, 1, 1]cpu" = torch.ops.aten.unsqueeze.default(sum_1, 1);  sum_1 = None
       convert_element_type_2: "bf16[224, 1, 1][1, 1, 1]cpu" = torch.ops.prims.convert_element_type.default(unsqueeze, torch.bfloat16);  unsqueeze = None
       view_2: "bf16[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.view.default(convert_element_type_2, [32, 7, 1, 1]);  convert_element_type_2 = None
       copy: "f32[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.copy.default(empty, view_2);  empty = view_2 = None
       
        # File: t.py:8 in matmul_outfp32, code: out = out + bias
       add: "f32[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.add.Tensor(copy, arg2_1);  copy = arg2_1 = None
       return (add,)

Additionally, in eager mode the call dispatches to bmm_out with the f32 output buffer, and the computation itself also runs in f32 precision.
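
This is easy to check numerically. If the out=f32 path really computes in f32, its result will generally differ from "compute in bf16, then upcast" by a small rounding error; if it computed in bf16 and merely copied, the two would match exactly. A quick sketch, reusing x, y and device from the script above:

# Compare out=f32 against "bf16 matmul, then cast to f32".
eager_out = torch.matmul(
    x, y, out=torch.empty(32, 7, 1, 1, dtype=torch.float32, device=device)
)
bf16_then_cast = torch.matmul(x, y).to(torch.float32)
print((eager_out - bf16_then_cast).abs().max())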

So, for the same inputs, the computation is performed at different precisions:

  • eager → f32 (uses bmm_out with bf16 inputs and an f32 output buffer)
  • inductor → f32 (promotes the inputs to f32 before the matmul)
  • aot_eager → bf16 (computes in bf16, then casts the result to f32)
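
The discrepancy is easy to surface by running the same function through all three paths and comparing each result against an explicit f32 reference. A rough sketch, again reusing matmul_outfp32, x, y and bias from the script above:

import torch._dynamo

# Run eager, aot_eager and inductor on the same inputs and compare each
# result against an explicit f32 reference computation.
ref = torch.matmul(x.float(), y.float()) + bias.float()

results = {"eager": matmul_outfp32(x, y, bias)}
for backend in ("aot_eager", "inductor"):
    torch._dynamo.reset()  # make sure each backend gets a fresh compilation
    compiled = torch.compile(matmul_outfp32, backend=backend)
    results[backend] = compiled(x, y, bias)

for name, out in results.items():
    print(f"{name}: max |diff| vs f32 reference = {(out - ref).abs().max().item():.3e}")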

I suspect that aot_eager could also use bmm_out, but functionalization decides that the empty tensor is just a fresh intermediate rather than a real input, so the out= call is replaced with an out-of-place bmm followed by a copy into the f32 buffer (which is exactly what the graph above shows).

Should it behave this way? The torch.matmul documentation does not specify any precision guarantees, and PyTorch’s built-in implementations are inconsistent, so it’s unclear which behavior is preferred.
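
For what it's worth, making the precision explicit in the model code sidesteps the ambiguity: upcasting the inputs before the matmul should give the same f32 numerics on every backend, at the cost of an extra cast. A sketch:

def matmul_fp32_explicit(x, y, bias):
    # Upcast explicitly so the matmul itself runs in f32 under eager,
    # aot_eager and inductor alike, instead of relying on out= to pick
    # the compute precision.
    out = torch.matmul(x.to(torch.float32), y.to(torch.float32))
    return out + bias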
