The "Ideal" PyTorch FLOP Counter (with __torch_dispatch__)

Hi,

When you compute out = mm(x, y) in the forward pass, the backward pass is given gOut and needs to compute gX and gY.
The formulas are gX = mm(gOut, y^T) and gY = mm(x^T, gOut), which explains why you see two mms in the backward pass.
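
If you want to check these two formulas yourself, here is a minimal sketch that compares them against what autograd computes. The shapes, the seed, and the variable names (g_out, g_x, g_y) are just picked for illustration.

```python
import torch

torch.manual_seed(0)

# Small double-precision inputs so the comparison is numerically tight.
x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
y = torch.randn(4, 5, dtype=torch.float64, requires_grad=True)

# Forward pass: one mm.
out = torch.mm(x, y)

# Pretend this is the gradient flowing in from the rest of the network.
g_out = torch.randn_like(out)
out.backward(g_out)

# Backward pass written by hand: two mms, one per input.
g_x = torch.mm(g_out, y.detach().t())  # gX = mm(gOut, y^T)
g_y = torch.mm(x.detach().t(), g_out)  # gY = mm(x^T, gOut)

print(torch.allclose(x.grad, g_x))
print(torch.allclose(y.grad, g_y))
```

Both hand-written gradients have the same shape as the input they belong to, which is a quick way to remember which side the transpose goes on.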
Also, since convolutions are just a special kind of mm, you get very similar formulas, and that’s where the transpose comes from.
I would recommend you look online for the difference between transposed and regular convolutions. There are blog posts with visualizations that will be much better than anything I could write!
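
To make the convolution case concrete, here is a small sketch, assuming a plain conv2d with stride 1, no padding and no bias (shapes and names are chosen only for illustration). It checks that the gradient with respect to the input matches a transposed convolution of gOut with the same weight.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Input: batch 1, 2 channels, 8x8. Weight: 3 output channels, 2 input channels, 3x3 kernel.
x = torch.randn(1, 2, 8, 8, dtype=torch.float64, requires_grad=True)
w = torch.randn(3, 2, 3, 3, dtype=torch.float64)

# Forward pass: regular convolution, output shape (1, 3, 6, 6).
out = F.conv2d(x, w)

# Incoming gradient with the same shape as the output.
g_out = torch.randn_like(out)
out.backward(g_out)

# Backward pass for the input written by hand: a transposed convolution
# of gOut with the same weight, which maps (1, 3, 6, 6) back to (1, 2, 8, 8).
g_x = F.conv_transpose2d(g_out, w)

print(g_x.shape == x.shape)
print(torch.allclose(x.grad, g_x))
```

With a stride or padding other than the defaults, the same arguments would need to be passed to conv_transpose2d as well, and strided cases may additionally need output_padding to recover the input size.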

Cheers,
Alban
