Hi,
When you do `out = mm(x, y)` during the forward pass, the backward pass is given `gOut` and needs to compute `gX` and `gY`.
The formulas for this are `gX = mm(gOut, y^T)` and `gY = mm(x^T, gOut)`, which should explain why you have 2 mms in the backward pass.
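If you want to convince yourself of these two formulas, here is a small sketch in plain Python (no PyTorch needed) that compares them against finite differences. The helpers `mm`, `transpose`, `loss` and `numeric_grad` are names I made up for this illustration, and the scalar loss `sum(out * gOut)` is just an assumed choice that makes the gradient flowing into `out` equal to `gOut`.

```python
def mm(a, b):
    """Matrix product of two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def loss(x, y, g_out):
    """Scalar L = sum(out * g_out), so that dL/d(out) is exactly g_out."""
    out = mm(x, y)
    return sum(o * g for row_o, row_g in zip(out, g_out)
               for o, g in zip(row_o, row_g))

def numeric_grad(f, m, eps=1e-4):
    """Central finite-difference gradient of the scalar f() w.r.t. matrix m."""
    grad = [[0.0] * len(m[0]) for _ in m]
    for i in range(len(m)):
        for j in range(len(m[0])):
            old = m[i][j]
            m[i][j] = old + eps
            up = f()
            m[i][j] = old - eps
            down = f()
            m[i][j] = old  # restore the original entry
            grad[i][j] = (up - down) / (2 * eps)
    return grad

def max_abs_diff(a, b):
    return max(abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))

x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]         # 2x3
y = [[1.0, 0.5], [-1.0, 2.0], [0.0, 3.0]]      # 3x2
g_out = [[1.0, -2.0], [0.5, 4.0]]              # 2x2, same shape as out

# The two mms of the backward pass.
g_x = mm(g_out, transpose(y))                  # 2x3, same shape as x
g_y = mm(transpose(x), g_out)                  # 3x2, same shape as y

# Numerical gradients of the same loss, for comparison.
num_g_x = numeric_grad(lambda: loss(x, y, g_out), x)
num_g_y = numeric_grad(lambda: loss(x, y, g_out), y)

print("max error on gX:", max_abs_diff(g_x, num_g_x))
print("max error on gY:", max_abs_diff(g_y, num_g_y))
```

Both errors should be tiny (floating-point noise only), and each gradient has the same shape as the input it belongs to.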
Also, since convolutions are just a special kind of mm, you get very similar formulas, and that’s where the transpose comes from.
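To make the link concrete, here is a rough plain-Python sketch for the simplest case I could pick: a 1D convolution with stride 1 and no padding (those settings, and all the helper names such as `conv_matrix` and `conv_transpose1d`, are assumptions for this illustration, not PyTorch functions). The convolution is written as a matrix `W` times the input, so its gradient with respect to the input is `W^T` times `gOut`, and that product turns out to be exactly a transposed convolution.

```python
def conv1d(x, w):
    """'Valid' 1D cross-correlation with stride 1 (the usual deep learning 'convolution')."""
    k = len(w)
    return [sum(w[j] * x[i + j] for j in range(k))
            for i in range(len(x) - k + 1)]

def conv_matrix(w, n):
    """Matrix W such that conv1d(x, w) equals W times x for any x of length n."""
    k = len(w)
    rows = []
    for i in range(n - k + 1):
        row = [0.0] * n
        row[i:i + k] = w          # the kernel, shifted by i positions
        rows.append(row)
    return rows

def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def transpose(m):
    return [list(col) for col in zip(*m)]

def conv_transpose1d(g_out, w):
    """Transposed 1D convolution, stride 1: scatter g_out[i] * w into the output."""
    k = len(w)
    g_x = [0.0] * (len(g_out) + k - 1)
    for i, g in enumerate(g_out):
        for j in range(k):
            g_x[i + j] += g * w[j]
    return g_x

x = [1.0, 2.0, 3.0, 4.0, 5.0]
w = [1.0, 0.0, -1.0]
g_out = [1.0, 2.0, 3.0]           # gradient w.r.t. the 3 outputs

W = conv_matrix(w, len(x))

# Forward: the convolution is the same thing as one mm with W.
out_direct = conv1d(x, w)
out_mm = matvec(W, x)

# Backward w.r.t. the input: multiply by W^T, as in the mm formula ...
g_x_mm = matvec(transpose(W), g_out)
# ... which gives the same result as a transposed convolution of g_out.
g_x_tconv = conv_transpose1d(g_out, w)

print(out_direct == out_mm, g_x_mm == g_x_tconv)
```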
I would recommend you check online for the difference between transposed and regular convolutions. There are blog posts with visualizations that will be much better than anything I could write!
Cheers,
Alban