Summation of bfloat16 tensors

Hi,

I’m seeing that torch’s summation of a bfloat16 tensor is more accurate than my naive summation, e.g. x.sum() is better than x[0] +…+ x[n-1]. Could somebody please point me towards where in the repo the code for torch’s summation lives? Thanks!
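For reference, here is a rough sketch of the kind of comparison I mean (the size and random data are just illustrative):

```python
import torch

# Illustrative comparison: naive left-to-right bfloat16 summation
# vs. torch's built-in reduction, measured against a float64 reference.
x = torch.rand(10_000).bfloat16()

naive = torch.zeros((), dtype=torch.bfloat16)
for v in x:                      # running sum is rounded to bfloat16 each step
    naive = naive + v

builtin = x.sum()                # torch's reduction

reference = x.double().sum()     # high-precision reference
print("naive error  :", (naive.double() - reference).abs().item())
print("x.sum() error:", (builtin.double() - reference).abs().item())
```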

sum is executed through TensorIterators, so it is a bit of a complex setup; this bit of code in ReduceOps.cpp plays a role.

The likely cause of your accuracy observation is that operations on 16-bit floats, including bfloat16, typically use 32-bit floats as the internal computation (“accumulation”) scalar type.
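As a minimal sketch of that effect (this is just plain Python casting, not the actual kernel code), accumulating the same values in float32 and rounding once at the end already recovers most of the accuracy:

```python
import torch

x = torch.rand(10_000).bfloat16()

# Accumulate directly in bfloat16: the running sum is rounded after
# every single addition.
acc_bf16 = torch.zeros((), dtype=torch.bfloat16)
for v in x:
    acc_bf16 = acc_bf16 + v

# Accumulate in float32 and round once at the end, which mimics the
# "accumulation scalar type" idea described above.
acc_fp32 = torch.zeros((), dtype=torch.float32)
for v in x:
    acc_fp32 = acc_fp32 + v.float()
rounded = acc_fp32.bfloat16()

reference = x.double().sum()
print("bfloat16 accumulator error:", (acc_bf16.double() - reference).abs().item())
print("float32 accumulator error :", (rounded.double() - reference).abs().item())
```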

Best regards

Thomas

Thank you for showing me that. Yes, it’s the FP32 accumulator that improves the accuracy.

That is one part of the story. The other part is that adding lots of numbers in a stable way is not a trivial task! PyTorch implements a fairly involved algorithm based on reducing parts of the tensor and then adding the results of those reductions.

The code is here: pytorch/SumKernel.cpp at bceb1db885cafa87fe8d037d8f22ae9649a1bba0 · pytorch/pytorch · GitHub
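As a rough illustration of the idea (not the actual kernel, which is vectorized, chunked, and quite a bit more involved), a pairwise-style reduction splits the input, reduces each part, and then adds the partial results:

```python
import torch

def pairwise_sum(x: torch.Tensor, block: int = 128) -> torch.Tensor:
    """Split the tensor in half, reduce each half recursively, and add the
    partial results; small blocks fall back to a simple float32 loop."""
    n = x.numel()
    if n <= block:
        acc = torch.zeros((), dtype=torch.float32)
        for v in x:
            acc = acc + v.float()
        return acc
    mid = n // 2
    return pairwise_sum(x[:mid], block) + pairwise_sum(x[mid:], block)

x = torch.rand(100_000).bfloat16()
print(pairwise_sum(x).bfloat16())   # pairwise-style reduction
print(x.sum())                      # PyTorch's built-in sum
```

The point of reducing in parts is that the rounding error then grows roughly like log(n) rather than like n for a plain left-to-right loop.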

A fantastic and up-to-date talk on the topic, with an algorithm similar to (actually better than) the one that PyTorch currently implements, is here: Talk by Nicholas J. Higham (University of Manchester) - YouTube

@Lezcano Thank you for sending this. I’ll check it out.