I'm seeing an inconsistent approach to precision in `torch.matmul` across backends, and I'd like to understand why. Consider this script:
```python
import torch

DEVICE = "cpu"
device = torch.device(DEVICE)

def matmul_outfp32(x, y, bias):
    out = torch.matmul(x, y, out=torch.empty(32, 7, 1, 1, dtype=torch.float32, device=device))
    out = out + bias
    return out

x = torch.randn(32, 7, 1, 8).to(torch.bfloat16).to(device)     # bf16 inputs
y = torch.randn(32, 7, 8, 1).to(torch.bfloat16).to(device)     # bf16 inputs
bias = torch.randn(32, 7, 1, 1).to(torch.bfloat16).to(device)  # bf16 bias

func = torch.compile(matmul_outfp32, backend="aot_eager")
out = func(x, y, bias)
print(out.dtype)
```
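(For reference, the graphs below can be reproduced with AOTAutograd's graph logging. A sketch, assuming a PyTorch 2.x build where `torch._logging.set_logs` accepts the `aot_graphs` artifact; `TORCH_LOGS="aot_graphs"` on the command line should be equivalent:)

```python
# Sketch: enable AOT graph logging before calling the compiled function.
# Assumes `aot_graphs` is a valid logging artifact name in this PyTorch version.
import torch._logging
torch._logging.set_logs(aot_graphs=True)
```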
When you run it with `aot_eager` (or, if you're using a custom backend, with `aot_autograd`), you get a graph like this:
```
===== Forward graph 0 =====
/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "bf16[32, 7, 8, 1][56, 8, 1, 1]cpu", arg1_1: "bf16[32, 7, 1, 8][56, 8, 8, 1]cpu", arg2_1: "bf16[32, 7, 1, 1][7, 1, 1, 1]cpu"):
        # File: t.py:7 in matmul_outfp32, code: out = torch.matmul(x, y, out=torch.empty(32, 7, 1, 1, dtype=torch.float32, device=device))
        empty: "f32[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.empty.memory_format([32, 7, 1, 1], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
        expand: "bf16[32, 7, 1, 8][56, 8, 8, 1]cpu" = torch.ops.aten.expand.default(arg1_1, [32, 7, 1, 8]); arg1_1 = None
        view: "bf16[224, 1, 8][8, 8, 1]cpu" = torch.ops.aten.view.default(expand, [224, 1, 8]); expand = None
        expand_1: "bf16[32, 7, 8, 1][56, 8, 1, 1]cpu" = torch.ops.aten.expand.default(arg0_1, [32, 7, 8, 1]); arg0_1 = None
        view_1: "bf16[224, 8, 1][8, 1, 1]cpu" = torch.ops.aten.view.default(expand_1, [224, 8, 1]); expand_1 = None
        bmm: "bf16[224, 1, 1][1, 1, 1]cpu" = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None
        view_2: "bf16[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.view.default(bmm, [32, 7, 1, 1]); bmm = None
        copy: "f32[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.copy.default(empty, view_2); empty = view_2 = None
        # File: t.py:8 in matmul_outfp32, code: out = out + bias
        add: "f32[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.add.Tensor(copy, arg2_1); copy = arg2_1 = None
        return (add,)
```
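In other words, everything up to the `copy` stays in `bf16`, and the result is only upcast when it is written into the preallocated `f32` buffer. A rough eager-mode paraphrase (a sketch, reusing `x`, `y`, `bias`, and `device` from the script above, not the exact decomposition):

```python
# Sketch of what the aot_eager graph computes: bf16 bmm, then upcast via copy.
tmp = torch.bmm(x.reshape(224, 1, 8), y.reshape(224, 8, 1))   # bf16 compute, already rounded
buf = torch.empty(32, 7, 1, 1, dtype=torch.float32, device=device)
buf.copy_(tmp.view(32, 7, 1, 1))                              # upcast happens after the rounding
result = buf + bias                                           # add in f32
```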
If you change the backend to `inductor`, it instead promotes the inputs to `f32`:
```
===== Forward graph 0 =====
/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "bf16[32, 7, 8, 1][56, 8, 1, 1]cpu", arg1_1: "bf16[32, 7, 1, 8][56, 8, 8, 1]cpu", arg2_1: "bf16[32, 7, 1, 1][7, 1, 1, 1]cpu"):
        # File: t.py:7 in matmul_outfp32, code: out = torch.matmul(x, y, out=torch.empty(32, 7, 1, 1, dtype=torch.float32, device=device))
        empty: "f32[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.empty.memory_format([32, 7, 1, 1], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
        expand: "bf16[32, 7, 1, 8][56, 8, 8, 1]cpu" = torch.ops.aten.expand.default(arg1_1, [32, 7, 1, 8]); arg1_1 = None
        view: "bf16[224, 1, 8][8, 8, 1]cpu" = torch.ops.aten.view.default(expand, [224, 1, 8]); expand = None
        expand_1: "bf16[32, 7, 8, 1][56, 8, 1, 1]cpu" = torch.ops.aten.expand.default(arg0_1, [32, 7, 8, 1]); arg0_1 = None
        view_1: "bf16[224, 8, 1][8, 1, 1]cpu" = torch.ops.aten.view.default(expand_1, [224, 8, 1]); expand_1 = None
        convert_element_type: "f32[224, 1, 8][8, 8, 1]cpu" = torch.ops.prims.convert_element_type.default(view, torch.float32); view = None
        convert_element_type_1: "f32[224, 8, 1][8, 1, 1]cpu" = torch.ops.prims.convert_element_type.default(view_1, torch.float32); view_1 = None
        squeeze: "f32[224, 8][8, 1]cpu" = torch.ops.aten.squeeze.dim(convert_element_type, 1); convert_element_type = None
        squeeze_1: "f32[224, 8][8, 1]cpu" = torch.ops.aten.squeeze.dim(convert_element_type_1, -1); convert_element_type_1 = None
        mul: "f32[224, 8][8, 1]cpu" = torch.ops.aten.mul.Tensor(squeeze, squeeze_1); squeeze = squeeze_1 = None
        sum_1: "f32[224, 1][1, 1]cpu" = torch.ops.aten.sum.dim_IntList(mul, [1], True); mul = None
        unsqueeze: "f32[224, 1, 1][1, 1, 1]cpu" = torch.ops.aten.unsqueeze.default(sum_1, 1); sum_1 = None
        convert_element_type_2: "bf16[224, 1, 1][1, 1, 1]cpu" = torch.ops.prims.convert_element_type.default(unsqueeze, torch.bfloat16); unsqueeze = None
        view_2: "bf16[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.view.default(convert_element_type_2, [32, 7, 1, 1]); convert_element_type_2 = None
        copy: "f32[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.copy.default(empty, view_2); empty = view_2 = None
        # File: t.py:8 in matmul_outfp32, code: out = out + bias
        add: "f32[32, 7, 1, 1][7, 1, 1, 1]cpu" = torch.ops.aten.add.Tensor(copy, arg2_1); copy = arg2_1 = None
        return (add,)
```
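Here the reduction runs in `f32`, but note that the graph still rounds the result back to `bf16` at `convert_element_type_2` before the `copy` into the `f32` buffer. A rough eager-mode paraphrase of this decomposition (again a sketch, reusing `x`, `y`, `bias`, and `device` from the script):

```python
# Sketch of inductor's decomposition: upcast, multiply-and-sum in f32,
# then round back to bf16 before the copy into the f32 buffer.
a = x.reshape(224, 1, 8).float().squeeze(1)     # [224, 8], f32
b = y.reshape(224, 8, 1).float().squeeze(-1)    # [224, 8], f32
s = (a * b).sum(dim=1, keepdim=True)            # f32 accumulation, [224, 1]
s_bf16 = s.unsqueeze(1).to(torch.bfloat16)      # rounded back to bf16, [224, 1, 1]
buf = torch.empty(32, 7, 1, 1, dtype=torch.float32, device=device)
buf.copy_(s_bf16.view(32, 7, 1, 1))             # upcast the already-rounded value
result = buf + bias                             # add in f32
```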
Additionally, in eager mode `bmm_out` is used, which also runs in `f32` precision.
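A quick way to probe this in eager mode (a sketch, assuming a `float32` `out=` tensor is accepted for `bfloat16` inputs, as the behavior described above suggests): compare the `out=`-based call with a plain `bf16` matmul that is upcast afterwards.

```python
# Sketch: if the eager kernel really accumulates into f32, the two results
# can differ in the low bits.
eager_out = torch.matmul(x, y, out=torch.empty(32, 7, 1, 1, dtype=torch.float32, device=device))
bf16_then_cast = torch.matmul(x, y).to(torch.float32)
print((eager_out - bf16_then_cast).abs().max())
```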
So, for the same inputs, the computation is performed at different precisions:

- eager → `f32` (uses `bmm_out` with `bf16` inputs and an `f32` output)
- inductor → `f32` (promotes the inputs to `f32` before the matmul)
- aot_eager → `bf16` (computes in `bf16`, then casts the result to `f32`)
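To make the divergence concrete, here is a sketch that runs the same function through all three paths and compares each result against an all-`f32` reference (it reuses the tensors from the script above, assumes the eager `out=` call is supported, and uses `torch._dynamo.reset()` so that each backend is compiled fresh):

```python
# Sketch: compare eager, aot_eager, and inductor against a float32 reference.
import torch._dynamo

ref = torch.matmul(x.float(), y.float()) + bias.float()   # all-f32 reference

results = {"eager": matmul_outfp32(x, y, bias)}
for backend in ("aot_eager", "inductor"):
    torch._dynamo.reset()                                  # drop previously compiled code
    compiled = torch.compile(matmul_outfp32, backend=backend)
    results[backend] = compiled(x, y, bias)

for name, res in results.items():
    print(name, (res - ref).abs().max().item())
```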
I suspect that `aot_eager` could also use `bmm_out`, but it determines that the `empty` tensor isn't really a graph input and therefore avoids the in-place/`out=` variant, emitting a functional `bmm` followed by a `copy` instead.
Should it behave this way? The `torch.matmul` documentation does not specify any precision guarantees, and PyTorch's built-in implementations are inconsistent with one another, so it's unclear which behavior is preferred.