RFC: Skip Module Parameter Initialization

Hey all, I’ve been working on support for instantiating PyTorch modules with uninitialized parameters / buffers (see #29523). This avoids unnecessary computation whenever the default initialization would be overwritten anyway, e.g. when applying a non-standard parameter initialization or loading from a serialized state_dict.

This is currently supported by the following 2-step process:

import torch

# Initialize module on the meta device.
m = torch.nn.Linear(5, 1, device='meta')

# Move module to CPU with empty / uninitialized parameters.
m.to_empty(device='cpu')
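
To make the intended use a bit more concrete, here is a minimal sketch of the two steps followed by a custom initialization. The constant fill values are arbitrary and chosen only for illustration; any initialization scheme (or a `load_state_dict` call) could take their place.

```python
import torch

# Step 1: construct on the meta device, so no real storage is allocated
# and the default initialization does no real work.
m = torch.nn.Linear(5, 1, device='meta')
assert m.weight.is_meta

# Step 2: allocate uninitialized storage on the CPU.
m.to_empty(device='cpu')
assert m.weight.device.type == 'cpu'

# The parameter values are arbitrary at this point, so they must be
# filled in before the module is used. The constants below are
# illustrative assumptions, not a recommended scheme.
torch.nn.init.constant_(m.weight, 0.5)
torch.nn.init.zeros_(m.bias)

out = m(torch.ones(2, 5))
```

Note that reading the parameters between step 2 and the explicit initialization yields whatever happened to be in memory, so nothing should depend on those values.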

This is a bit arcane for the simple “skip init” use case, so I opened a PR for a helper function that will do the above in a more sugary way:

import torch

# Instantiate a module with uninitialized parameters.
m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
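
For discussion purposes, here is a rough sketch of how such a helper could wrap the 2-step process. The name `skip_init_sketch` and the handling of the `device` keyword are assumptions made for illustration; this is not necessarily what the PR implements, and it only works for module classes whose constructor accepts a `device` argument.

```python
import torch

def skip_init_sketch(module_cls, *args, **kwargs):
    """Hypothetical helper: build a module without running real init work."""
    # Assumption: the caller may pass the final device; default to CPU.
    final_device = kwargs.pop('device', 'cpu')
    # Construct on the meta device, then allocate uninitialized storage
    # on the requested device. to_empty() returns the module itself.
    module = module_cls(*args, device='meta', **kwargs)
    return module.to_empty(device=final_device)

# Usage: skip the default init, then load known values.
m = skip_init_sketch(torch.nn.Linear, 5, 1)
state = {'weight': torch.ones(1, 5), 'bias': torch.zeros(1)}
m.load_state_dict(state)
```

The `load_state_dict` call stands in for any step that overwrites the parameters, which is the case where skipping the default initialization saves work.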

I’m hoping for comments / suggestions / name bikeshedding for the proposed sugary version. Please add your opinions if you have them :slight_smile: Thanks!
