Extension points for Tensor subclass composability with `nn.Module`

We recently added some extension points that improve the composability of tensor subclasses with `nn.Module`.

In particular,

  • An extension point for `load_state_dict` that lets you define custom logic for loading into or from tensor subclasses without changing the Python references to the parameters.
  • Improved composability with `nn.Module.to()` and related methods.

See the tutorial "Extension points in nn.Module for load_state_dict and tensor subclasses" (PyTorch Tutorials 2.3.0+cu121 documentation) for more details.

Let us know if you have any feedback!
