Torch.optim happenings: More Ws (2023 H2)

@raghukiran1224 Please correct me if I misunderstood: I'm guessing you are asking about torch.compile() composing with FSDP for optimizers?

In theory, fusion should work independently of whether the inputs are full or sharded for pointwise optimizers (which covers all of our foreach optimizers): Inductor should be able to fuse a series of foreach operations on the same memory without trouble. I am less sure about fusing with the ops before and after the optimizer step, since torch.compile() support for FSDP's distributed collectives is still underway (see an update here: Torch.compile() + FSDP - Dec 8th), and @awgu would know more about the ops directly before and after the optimizer portion of FSDP.
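As a toy illustration (not the actual optimizer code), here is the kind of pointwise foreach sequence I mean, using the private `torch._foreach_*` ops that the foreach optimizers are built on; under torch.compile(), consecutive ops like these are what Inductor can fuse:

```python
import torch

# Two small "parameter" tensors and matching "gradients"; the shapes are
# arbitrary since every op below is pointwise.
params = [torch.ones(4), torch.full((4,), 2.0)]
grads = [torch.full((4,), 0.5), torch.full((4,), 0.25)]
lr = 0.1

# A series of pointwise foreach ops over the whole list at once, e.g. a
# toy scale-then-update. Whether `params` holds full or sharded tensors
# does not change the math, which is why fusion should be shard-agnostic.
torch._foreach_mul_(grads, 0.5)                # grads *= 0.5
torch._foreach_add_(params, grads, alpha=-lr)  # params -= lr * grads

print(params[0][0].item())  # approximately 0.975
```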

Practically speaking, since distributed collectives are not compilable today, there is no full-graph capture of FSDP. In other words, only the code regions between the collectives (like all_gather and reduce_scatter) get compiled. A similarly chopped-up approach would have to take place for the optimizer: one would manually replace the normal optimizer.step() call with a torch.compile()'d version. In pseudocode (I have not ascertained this works), I'm imagining something like:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import AdamW

model = ...  # your nn.Module
fsdp_mod = FSDP(model)
optim = AdamW(fsdp_mod.parameters())
optim.step = torch.compile(optim.step)  # untested: patches the bound step() method
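To sanity-check just the patching pattern (single process, no FSDP involved), here is a minimal sketch; rebinding `optim.step` is ordinary attribute assignment, and `backend="eager"` runs the captured graph with regular eager kernels, so this only exercises capture, not Inductor fusion:

```python
import torch
from torch import nn
from torch.optim import AdamW

torch.manual_seed(0)
model = nn.Linear(4, 2)
optim = AdamW(model.parameters(), lr=1e-2)

# Rebind step() to a compiled version; backend="eager" skips codegen so
# this checks the call pattern only, not any kernel fusion.
optim.step = torch.compile(optim.step, backend="eager")

before = model.weight.detach().clone()
model(torch.randn(8, 4)).sum().backward()
optim.step()  # the compiled step updates the parameters
optim.zero_grad()

changed = not torch.equal(before, model.weight.detach())
print(changed)
```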

That said, I suspect no one on the team has tried this yet, but this area is under active development, so if you do get a chance to try it, please let us know how it goes! Also cc'ing @wconstab, @mlazos, @voz, @awgu, who would know more.