How to measure your model's memory usage without running it?

Hi @albanD, being able to inspect memory usage without actually allocating any memory is really cool, and I’d like to complete the PR.

Referring to Memory Tracker for tracking Module wise memory by sanketpurandare · Pull Request #124688 · pytorch/pytorch · GitHub and to FlopCounterMode, I have refactored the code to track the memory usage of each module. Here are my initial ideas:

  • Inheriting from FlopCounterMode to create a MemoryTrackingMode (a rough sketch follows this list).
  • Using a dict, memory_tracker_collections, to map tensors to the names of the modules that use them, keyed by the tensor’s _RefType:
        # Maps each _RefType name to {tensor's untyped storage: set of module names}
        self.memory_tracker_collections: Dict[str, WeakIdKeyDictionary[torch.storage.UntypedStorage, set]] = (
            defaultdict(WeakIdKeyDictionary)
        )

The full code is here: MemoryTrackingMode.py · GitHub. Do you have any suggestions regarding this implementation?
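
For context, this is roughly how I picture it being used to get module-wise numbers without allocating real memory. Running everything under FakeTensorMode is my assumption here, as is nbytes() on a fake storage reporting the would-be size; the model and sizes below are made up:

    # Hypothetical usage of the sketch above; model and sizes are made up.
    from collections import defaultdict

    import torch
    import torch.nn as nn
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        # Everything created under FakeTensorMode is fake: nothing is really allocated.
        model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
        x = torch.randn(32, 1024)
        with MemoryTrackingMode() as mt:
            out = model(x)  # keep `out` alive so its storage stays in the weak dict

    # Aggregate bytes per module name. Storages freed during forward have already
    # dropped out of the weak dict, so this only counts what is still alive.
    per_module = defaultdict(int)
    for storages in mt.memory_tracker_collections.values():
        for storage, module_names in storages.items():
            for name in module_names:
                per_module[name] += storage.nbytes()
    print(dict(per_module))

One consequence of the weak keys is that intermediate activations freed during forward vanish before aggregation; I’m not yet sure whether that is the behavior we want.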

While testing FlopCounterMode in a loop, I ran into a crash with a more complex model:

    AssertionError: Global is not DummyModel.module.

Here is the code to reproduce the issue: mt_with_loop.py · GitHub
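
For readers who can’t open the gist, the failing pattern looks roughly like the following. This is my reconstruction from the assertion message rather than the gist itself, so DummyModel and the loop structure are only indicative:

    # Reconstructed sketch of the failing pattern, not the actual gist contents.
    import torch
    import torch.nn as nn
    from torch.utils.flop_counter import FlopCounterMode


    class DummyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.module = nn.Linear(16, 16)

        def forward(self, x):
            return self.module(x)


    model = DummyModel()
    x = torch.randn(4, 16)
    flop_counter = FlopCounterMode(model, display=False)
    for _ in range(2):
        # Re-entering the same mode instance across loop iterations is where
        # the module-entry bookkeeping appears to go out of sync.
        with flop_counter:
            model(x)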