This is a very cool feature! Formulating this problem as a min-cut is quite neat!
Quick question: Is there a way to configure recomputable_ops based on each op's ratio of compute to memory traffic? It looks to me like the current recomputable_ops list is hard-coded by operator type, which could miss some optimization opportunities.
For example, let’s say that we are looking at A @ B with two configurations:
Config 1: A.shape = (100, 100), B.shape = (100, 100), Dim_M = Dim_K = Dim_N = 100
Config 2: A.shape = (100, 2), B.shape = (2, 2), Dim_M = 100, Dim_K = 2, Dim_N = 2
Configs 1 and 2 are both MatMuls, so both would fall under unrecomputable_ops. However, in Config 2 the compute / memory footprint of A @ B looks more like that of an element-wise operator: the amount of compute per row of A is very low, since Dim_K and Dim_N << Dim_M. Config 2 therefore makes A @ B a good candidate for recomputation.
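To make the ratio concrete, here is a rough sketch of the kind of shape-aware check I have in mind. It is plain Python and not tied to the actual partitioner code: the function names and the threshold of 4 FLOPs per element are my own assumptions for illustration, and a real heuristic would presumably also account for dtype and hardware.

```python
def matmul_arithmetic_intensity(m: int, k: int, n: int) -> float:
    """FLOPs per tensor element touched for an (m, k) @ (k, n) matmul."""
    flops = 2 * m * k * n                     # one multiply and one add per (m, k, n) triple
    elements_touched = m * k + k * n + m * n  # read A, read B, write the output
    return flops / elements_touched


def is_cheap_to_recompute(m: int, k: int, n: int, threshold: float = 4.0) -> bool:
    """Hypothetical rule: treat a low-intensity matmul like an element-wise op.

    The threshold is an assumed, untuned value.
    """
    return matmul_arithmetic_intensity(m, k, n) < threshold


config_1 = matmul_arithmetic_intensity(100, 100, 100)  # 2,000,000 FLOPs / 30,000 elements
config_2 = matmul_arithmetic_intensity(100, 2, 2)      # 800 FLOPs / 404 elements

print(f"Config 1: {config_1:.2f} FLOPs per element")
print(f"Config 2: {config_2:.2f} FLOPs per element")
```

With these numbers, Config 1 comes out at roughly 67 FLOPs per element and Config 2 at roughly 2, so a rule like this would keep Config 1 unrecomputable and allow Config 2 to be recomputed.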
If we want to make Min-cut recomputation shape-aware, what in your view could be a good way to approach this?
Thanks!