Thanks for the pointer, it really helps!
Plus: I think the sin example in the slides has no memory benefit. The other example you provide in Min-cut optimal(*) recomputation (i.e. activation checkpointing) with AOTAutograd does show a memory benefit.
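To make that concrete, here is a minimal sketch of what I mean, assuming the slide example is y = sin(x) (whose backward needs cos(x)); the function names are just for illustration:

```python
import torch

# Sketch (assumed example, not from the slides): forward y = sin(x),
# backward needs cos(x) * grad_y.  Whether we stash cos(x) in the
# forward, or stash only x and recompute cos(x) in the backward,
# exactly one tensor of x's size stays alive between fwd and bwd,
# so there is no activation-memory saving for this single op.

def fwd_save_cos(x):
    y = torch.sin(x)
    return y, torch.cos(x)               # keep cos(x) alive for bwd

def bwd_from_saved_cos(saved_cos, grad_y):
    return saved_cos * grad_y

def fwd_save_x(x):
    y = torch.sin(x)
    return y, x                          # keep x alive for bwd instead

def bwd_recompute_cos(saved_x, grad_y):
    return torch.cos(saved_x) * grad_y   # recompute cos(x) in bwd

x = torch.randn(4)
g = torch.ones_like(x)
y1, s1 = fwd_save_cos(x)
y2, s2 = fwd_save_x(x)
assert torch.allclose(bwd_from_saved_cos(s1, g), bwd_recompute_cos(s2, g))
```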
Plus: I’m confused by the word “partition”. By partition, we usually mean partitioning the graph into disjoint subsets. In the following example, if we save add_2 for backward, both the fwd and bwd graphs have to compute the cos edge on add_2. When we partition the joint graph to get the bwd graph based on activations {x1, ..., xn}, the fwd graph should be the minimum graph that can compute {x1, ..., xn, output}. Those two graphs might have some overlap.
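Here is a toy sketch of the overlap I have in mind; the graph (add_2 = x + 1, y = cos(add_2), out = y * y) is made up for illustration and is not the one from your post:

```python
import torch

# Toy joint graph (made up):  add_2 = x + 1;  y = cos(add_2);  out = y * y.
# The backward needs cos(add_2) (for d out/d y = 2 * y) and sin(add_2)
# (for d y/d add_2).  If the partitioner saves only add_2, then cos is
# computed in the fwd graph (to build out) and recomputed in the bwd
# graph, so the node appears in both "partitions".

def fwd(x):
    add_2 = x + 1
    y = torch.cos(add_2)                 # cos in the fwd graph
    out = y * y
    return out, (add_2,)                 # only add_2 is saved

def bwd(saved, grad_out):
    (add_2,) = saved
    y = torch.cos(add_2)                 # cos recomputed in the bwd graph
    grad_y = grad_out * 2 * y
    return grad_y * (-torch.sin(add_2))  # grad wrt x, since d add_2/dx = 1

# quick check against autograd
x = torch.randn(4, requires_grad=True)
out, saved = fwd(x)
out.sum().backward()
assert torch.allclose(bwd(saved, torch.ones_like(x)), x.grad)
```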
Final Plus: I used to think aot autograd could discover something like an optimized sigmoid (e.g. users write the eager code z = 1 / (1 + torch.exp(-x)), and it could figure out the smart backward z * (1 - z)). Now I understand what aot autograd can actually do.
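For reference, a quick numeric check of the sigmoid identity I had in mind, i.e. that the “smart” backward z * (1 - z) matches what autograd derives from the eager formula:

```python
import torch

# For z = 1 / (1 + exp(-x)) (i.e. sigmoid), dz/dx simplifies to z * (1 - z).
# Check that this "smart" backward matches autograd on the eager formula.

x = torch.randn(5, requires_grad=True)
z = 1 / (1 + torch.exp(-x))
z.sum().backward()

smart_grad = z.detach() * (1 - z.detach())
assert torch.allclose(x.grad, smart_grad)
```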