There are some recent additions that improve the composability of tensor subclasses with `nn.Module`.
In particular,
- An extension point for `load_state_dict` that one can use to define custom logic when loading to/from subclasses without changing the Python references to the parameters.
- Improved composability with `nn.Module.to` and related methods.
See this tutorial for more details: "Extension points in nn.Module for load_state_dict and tensor subclasses" (PyTorch Tutorials 2.3.0+cu121 documentation).
Let us know if you have any feedback!