Hi @albanD, it's really cool to be able to inspect memory usage without actually using any memory, and I'd like to complete the PR.
Referring to the PR "Memory Tracker for tracking Module wise memory" by sanketpurandare (Pull Request #124688, pytorch/pytorch) and to `FlopCounterMode`, I have refactored the code to track the memory usage of each module. Here are my initial ideas:
- Inheriting from `FlopCounterMode` to create the `MemoryTrackingMode`.
- Using a dict, `memroy_tracker_collections`, to map the relationship between tensors and module names, grouped by each tensor's `_RefType`.
# The relationship between tensor and module name, {_RefType name: {tensor's storage: {module name set}}}
self.memroy_tracker_collections: Dict[str, WeakIdKeyDictionary[torch.storage.UntypedStorage, set]] = (
    defaultdict(WeakIdKeyDictionary)
)
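To make the shape of that mapping concrete, here is a minimal, standard-library-only sketch of the same idea. It is not the code from the gist: `FakeStorage`, `TrackerCollections`, `record`, and `bytes_per_module` are hypothetical names, the string labels stand in for `_RefType` names, and `weakref.WeakKeyDictionary` is used as a stand-in for `WeakIdKeyDictionary` so that the example runs without PyTorch.

```python
import gc
import weakref
from collections import defaultdict
from typing import Dict, Set


class FakeStorage:
    """Hypothetical stand-in for a tensor's untyped storage.

    Plain Python objects hash by identity, so a WeakKeyDictionary keyed on
    them behaves like an identity-keyed weak dictionary for this sketch.
    """

    def __init__(self, nbytes: int) -> None:
        self.nbytes = nbytes


class TrackerCollections:
    """Sketch of {ref type name: {storage: {module names}}} with weak keys."""

    def __init__(self) -> None:
        # Outer dict: ref type name -> weak mapping of storage -> module names.
        self.collections: Dict[str, weakref.WeakKeyDictionary] = defaultdict(
            weakref.WeakKeyDictionary
        )

    def record(self, ref_type: str, storage: FakeStorage, module_name: str) -> None:
        # One storage may be attributed to several (nested) modules.
        self.collections[ref_type].setdefault(storage, set()).add(module_name)

    def bytes_per_module(self, ref_type: str) -> Dict[str, int]:
        # Use .get so that querying an unknown ref type does not create an entry.
        per_storage = self.collections.get(ref_type)
        totals: Dict[str, int] = defaultdict(int)
        if per_storage is not None:
            for storage, names in list(per_storage.items()):
                for name in names:
                    totals[name] += storage.nbytes
        return dict(totals)


# Usage: attribute one activation storage to a module and its parent.
tracker = TrackerCollections()
activation = FakeStorage(256)
tracker.record("activation", activation, "DummyModel.linear")
tracker.record("activation", activation, "DummyModel")
print(tracker.bytes_per_module("activation"))

# Once the storage is no longer referenced, its entry disappears on its own,
# so the tracker does not keep freed storages alive.
del activation
gc.collect()
print(tracker.bytes_per_module("activation"))
```

The weak keys are the main design point: the tracker only observes storages and never extends their lifetime, which is why the per-module totals shrink as tensors are freed.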
You can find the code snippet here: MemoryTrackingMode.py · GitHub. Do you have any suggestions regarding this implementation?
While testing `FlopCounterMode` with a loop, I encountered a crash with a complex model: `AssertionError: Global is not DummyModel.module`.
Here is the code to reproduce the issue. mt_with_loop.py · GitHub