Min-cut optimal(*) recomputation (i.e. activation checkpointing) with AOTAutograd

@kelayamatoz I think it would be entirely feasible to modify the “ban recomputation” algorithm so that it checks whether operators we normally consider compute-bound (like matmuls) are actually bandwidth-bound for the shapes involved, and recomputes them if so.
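One way to make "actually bandwidth-bound" concrete is a roofline estimate. The sketch below compares a matmul's arithmetic intensity (FLOPs per byte moved) against the device's ridge point. `Hardware` and `matmul_is_bandwidth_bound` are hypothetical names, the peak numbers are illustrative placeholders, and the byte count assumes each input is read once and the output written once. This is not how AOTAutograd's partitioner is implemented.

```python
from dataclasses import dataclass


@dataclass
class Hardware:
    # Hypothetical peak numbers; substitute the real specs of your device.
    peak_flops: float      # FLOP/s
    peak_bandwidth: float  # bytes/s

    @property
    def ridge_point(self) -> float:
        # Arithmetic intensity (FLOP/byte) above which a kernel is compute-bound.
        return self.peak_flops / self.peak_bandwidth


def matmul_is_bandwidth_bound(m: int, k: int, n: int, hw: Hardware,
                              bytes_per_elem: int = 2) -> bool:
    """Roofline estimate for an (m x k) @ (k x n) matmul."""
    flops = 2 * m * k * n  # one multiply + one add per inner-product term
    # Idealized traffic: read both inputs once, write the output once.
    bytes_moved = (m * k + k * n + m * n) * bytes_per_elem
    intensity = flops / bytes_moved
    return intensity < hw.ridge_point


hw = Hardware(peak_flops=312e12, peak_bandwidth=2e12)  # illustrative values
# A square matmul is compute-bound; a matrix-vector product is not.
print(matmul_is_bandwidth_bound(4096, 4096, 4096, hw))  # False
print(matmul_is_bandwidth_bound(4096, 4096, 1, hw))     # True
```

A matrix-vector product has an intensity of roughly 1 FLOP/byte in fp16, far below the ridge point. This is the kind of "matmul" that would be cheap to recompute rather than save.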

We actually have a similar check in the other direction: we ban recomputing anything (primarily reductions) whose output is more than 4x smaller than its input. The idea is that although a reduction on its own is always bandwidth-bound, a composition of, say, a broadcast followed by a reduction might not be (a matmul, for example, can be expressed exactly this way)!
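A minimal sketch of that size rule, plus a matmul written as broadcast + reduction to show why it triggers. `ban_recompute_reduction` and `matmul_via_broadcast_reduce` are hypothetical names, and the threshold follows the ">4x smaller" description above rather than the partitioner's actual code.

```python
def ban_recompute_reduction(input_numel: int, output_numel: int,
                            ratio: int = 4) -> bool:
    # Ban recomputation when the output is more than `ratio`x smaller than the input.
    return input_numel > ratio * output_numel


def matmul_via_broadcast_reduce(a, b):
    """(m x k) @ (k x n) as a broadcast multiply followed by a sum over k."""
    m, k, n = len(a), len(b), len(b[0])
    # Broadcast step: the (m, k, n) intermediate a[i][p] * b[p][j].
    prod = [[[a[i][p] * b[p][j] for j in range(n)] for p in range(k)]
            for i in range(m)]
    # Reduction step: sum over the shared k dimension -> (m, n) output.
    out = [[sum(prod[i][p][j] for p in range(k)) for j in range(n)]
           for i in range(m)]
    return out, prod


out, prod = matmul_via_broadcast_reduce([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(out)  # [[19, 22], [43, 50]]
```

The reduction's input is the m*k*n broadcast intermediate and its output is m*n, so the shrink factor is exactly k. Any matmul with an inner dimension above 4 is therefore kept rather than recomputed, e.g. `ban_recompute_reduction(2 * 8 * 2, 2 * 2)` is `True`.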
