Hi @albanD, it's really cool to be able to inspect memory usage without actually using any memory, and I'd like to complete the PR.
Building on Memory Tracker for tracking Module wise memory by sanketpurandare · Pull Request #124688 · pytorch/pytorch · GitHub and on FlopCounterMode, I have refactored the code to track memory usage per module. Here are my initial ideas:
- Inherit from `FlopCounterMode` to create `MemoryTrackingMode`.
- Use a dict, `memroy_tracker_collections`, to map each tensor to the names of the modules that use it, grouped by the tensor's `_RefType`:
```python
# The relationship between tensor and module name, {_RefType name: {tensor's storage: {module name set}}}
self.memroy_tracker_collections: Dict[str, WeakIdKeyDictionary[torch.storage.UntypedStorage, set]] = (
    defaultdict(WeakIdKeyDictionary)
)
```
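To make the bookkeeping behind this nested dict concrete, here is a minimal pure-Python sketch that does not depend on PyTorch. It is an illustration under assumptions, not the implementation in the gist: `FakeStorage` is a hypothetical stand-in for `torch.storage.UntypedStorage`, `RefType` is a hypothetical stand-in for `_RefType` (its members are made up), `StorageRegistry` is a hypothetical name, and the standard library's `weakref.WeakKeyDictionary` replaces `WeakIdKeyDictionary`. The two behave the same here only because `FakeStorage` keeps Python's default identity-based hashing.

```python
import gc
import weakref
from collections import defaultdict
from enum import Enum, auto
from typing import Dict, Set


class RefType(Enum):
    # Hypothetical categories; the real _RefType members may differ.
    PARAMETER = auto()
    ACTIVATION = auto()
    GRADIENT = auto()


class FakeStorage:
    """Hypothetical stand-in for torch.storage.UntypedStorage: only a byte size."""

    def __init__(self, nbytes: int) -> None:
        self.nbytes = nbytes


class StorageRegistry:
    """Hypothetical helper mirroring {ref type name: {storage: {module names}}}."""

    def __init__(self) -> None:
        # Storages are held weakly, so an entry vanishes once the storage
        # itself is garbage-collected; the tracker never keeps memory alive.
        self.collections: Dict[str, weakref.WeakKeyDictionary] = defaultdict(
            weakref.WeakKeyDictionary
        )

    def record(self, ref_type: RefType, storage: FakeStorage, module_name: str) -> None:
        # A set of module names per storage: recording twice is a no-op.
        modules: Set[str] = self.collections[ref_type.name].setdefault(storage, set())
        modules.add(module_name)

    def bytes_per_module(self, ref_type: RefType) -> Dict[str, int]:
        # Keying by storage (not by tensor) means several views of the same
        # storage are counted once per module.
        totals: Dict[str, int] = defaultdict(int)
        for storage, modules in self.collections[ref_type.name].items():
            for name in modules:
                totals[name] += storage.nbytes
        return dict(totals)


registry = StorageRegistry()
weight = FakeStorage(400)
act = FakeStorage(1024)
registry.record(RefType.PARAMETER, weight, "model.linear")
registry.record(RefType.ACTIVATION, act, "model.linear")
registry.record(RefType.ACTIVATION, act, "model")  # also attributed to the parent module
registry.record(RefType.ACTIVATION, act, "model")  # duplicate record, not double counted
print(registry.bytes_per_module(RefType.ACTIVATION))  # {'model.linear': 1024, 'model': 1024}, key order may vary
del act
gc.collect()
print(registry.bytes_per_module(RefType.ACTIVATION))  # {}
```

The weak keys are the main design point: once the last reference to a storage is dropped, its entry disappears from every module's tally without any explicit cleanup.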
You can find the code snippet here: MemoryTrackingMode.py · GitHub. Do you have any suggestions regarding this implementation?
While testing FlopCounterMode inside a loop, I encountered a crash with a more complex model:
AssertionError: Global is not DummyModel.module.
Here is the code to reproduce the issue. mt_with_loop.py · GitHub