Hey all, I’ve been working on making it possible to instantiate PyTorch modules with uninitialized parameters / buffers (see #29523). This functionality helps avoid unnecessary computation when doing non-standard parameter initialization, when loading from a serialized `state_dict`, etc.
This is currently supported via the following two-step process:
```python
import torch

# Initialize module on the meta device.
m = torch.nn.Linear(5, 1, device='meta')

# Move module to CPU with empty / uninitialized parameters.
m.to_empty(device='cpu')
```
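As a concrete illustration of the `state_dict` use case mentioned above, the default initialization can be skipped entirely when the parameters are about to be overwritten anyway (a minimal sketch; `linear.pt` is a hypothetical checkpoint path used only for illustration):

```python
import torch

# Construct without running the default initialization.
m = torch.nn.Linear(5, 1, device='meta').to_empty(device='cpu')

# The empty parameters are fully overwritten by the serialized
# state_dict, so the skipped initialization was never needed.
m.load_state_dict(torch.load('linear.pt'))
```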
This is a bit arcane for the simple “skip init” use case, so I opened a PR adding a helper function that does the above in a more sugary way:
```python
import torch

# Instantiate a module with uninitialized parameters.
m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
```
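For those curious about the mechanics, a helper like this can be written as a thin wrapper around the two-step process above. This is just a sketch under the assumption that the wrapped module class accepts a `device` keyword argument; it is not necessarily what the PR implements:

```python
import torch

def skip_init_sketch(module_cls, *args, **kwargs):
    # Hypothetical helper for illustration; assumes module_cls accepts
    # a `device` kwarg. Build on the meta device so no parameter memory
    # is allocated and no init code runs, then materialize empty
    # storage on the target device.
    final_device = kwargs.pop('device', 'cpu')
    kwargs['device'] = 'meta'
    return module_cls(*args, **kwargs).to_empty(device=final_device)
```

The typical follow-up is a custom initialization, e.g. `torch.nn.init.orthogonal_(m.weight)`, so none of the default init work is wasted.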
I’m hoping for comments / suggestions / name bikeshedding on the proposed sugary version. Please add your opinions if you have them. Thanks!