1. Do you mean that making nn.Module parameters inputs to the graph, rather than attributes of the outer GraphModule, is what allows autograd to call the backward hooks? Or is it only needed to make FSDP run correctly?
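
To make sure I'm asking about the right thing, here is a minimal sketch of what I understand by "parameters as graph inputs", using torch.func.functional_call purely as my own illustration (I know this is not necessarily the actual FSDP/Dynamo code path):

```python
import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

mod = Tiny()

# Attribute style: the module (or a traced GraphModule) holds its
# parameters internally as attributes.
out_attr = mod(torch.randn(2, 4))

# Input style: parameters are passed in explicitly, so autograd sees them
# as inputs to the call, which is what I assume lets per-parameter
# backward/AccumulateGrad hooks fire.
params = dict(mod.named_parameters())
out_input = torch.func.functional_call(mod, params, (torch.randn(2, 4),))
```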
2. You said the problem of overlapping backward compute with gradient reduction is worked around by "UnspecializedNNModuleVariable". How does UnspecializedNNModuleVariable resolve that problem, and where can I find documentation or a reference about it?
3. Can Lazy Tensor Core trace collective communication ops? I found that pytorch/xla can trace these ops, implemented via a custom plugin. Is that the right way to compile a model that contains collective ops?
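
For context, this is the kind of pattern I mean by "a model with collective ops" (my own minimal example, using a single-process gloo group just so it runs standalone; the real question is whether Lazy Tensor Core / pytorch/xla can capture the all_reduce into its trace):

```python
import os
import torch
import torch.distributed as dist

# Single-process process group only so the snippet is self-contained.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

def forward_with_collective(x: torch.Tensor) -> torch.Tensor:
    y = x * 2
    # Collective op inside the region I would like the backend to trace,
    # rather than fall back / graph-break around it.
    dist.all_reduce(y, op=dist.ReduceOp.SUM)
    return y

print(forward_with_collective(torch.ones(4)))

dist.destroy_process_group()
```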
Thanks