The short answer is this 30-line TorchDispatchMode that tracks all Tensor memory use: subclass_zoo/max_mem_tracker.py at main · albanD/subclass_zoo · GitHub
import math
import weakref

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary

# Track all the memory being used by Tensors.
# Only max is tracked but others can be added.
MEMORY_USE = WeakIdKeyDictionary()
MEMORY_MAX = 0
# Minimum allocation size
PYTORCH_MIN_ALLOCATE = 2**9

def update_stats():
    global MEMORY_MAX
    curr_use = 0
    for k, v in MEMORY_USE.items():
        # Round each Storage up to the allocator's minimum block size.
        curr_use += math.ceil(k.size() * k.element_size() / PYTORCH_MIN_ALLOCATE) * PYTORCH_MIN_ALLOCATE
    if MEMORY_MAX < curr_use:
        MEMORY_MAX = curr_use

# Should be called on every Tensor created
def track(t: torch.Tensor):
    def cb(_):
        # Re-compute stats when a tracked Storage gets deallocated.
        update_stats()
    st = t.untyped_storage()
    wt = weakref.ref(st, cb)
    MEMORY_USE[st] = wt
    update_stats()

# Use this Mode to call track on every Tensor being created by functions
class MemoryTrackingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        res = func(*args, **kwargs or {})
        tree_map_only(torch.Tensor, track, res)
        return res
This has the added benefit of working under FakeTensorMode, meaning that you can measure memory usage without actually allocating any memory! Even better, in (most) cases you can measure CUDA memory use with a CPU-only build.
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode(), MemoryTrackingMode():
    def f(a):
        b = a * 10
        d = b + 3
        return d

    # This works even on a cpu-only build!
    a = torch.rand(100, device="cuda")
    f(a)
print(f"Just f: {MEMORY_MAX}")
What is next?
Tell us what you want!
- Memory Tracker for tracking Module wise memory by sanketpurandare · Pull Request #124688 · pytorch/pytorch · GitHub is using this and adding an nn.Module tracker to provide Module-aware memory usage (a rough sketch of that idea is shown after this list).
  - We should refactor the Module tracking so it can be shared with the FlopCounter code.
  - If anyone wants to help get that PR merged faster, feel free to reach out!
- Are you interested in memory fragmentation? Should we have a way to mimic the behavior of our caching allocator to allow us to report not only max memory allocated but also max memory reserved? (A toy illustration follows after this list.)
  - I think we should be able to do that either by refactoring the caching allocator to be called from Python or by mimicking its behavior in Python.
- Do we want to have some kind of integration with the memory visualizer from Understanding GPU Memory 1: Visualizing All Allocations over Time | PyTorch to allow generating these memory profiles from within FakeTensorMode?
  - I think we should be able to do that by either re-using the capture logic in the caching allocator or by generating the same trace directly from our mockup.
- Other features you would like to see there?
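To make the Module-tracking idea above more concrete, here is a minimal sketch (my own illustration, not the implementation from PR #124688) that records how much the tracked peak grows around each nn.Module's forward using hooks. MODULE_PEAK and attach_memory_hooks are names made up for this example; it builds on MEMORY_MAX, MemoryTrackingMode, and FakeTensorMode from the code above.

import torch.nn as nn

# Hypothetical sketch: per-module growth of the tracked peak, recorded via
# forward pre/post hooks. PR #124688 builds a real nn.Module tracker; this
# only illustrates the general idea.
MODULE_PEAK = {}

def attach_memory_hooks(model: nn.Module):
    for name, mod in model.named_modules():
        def pre_hook(m, args, _name=name):
            m._mem_max_before = MEMORY_MAX

        def post_hook(m, args, output, _name=name):
            # How much the global peak grew while this module (and its
            # children) were running.
            MODULE_PEAK[_name or "<root>"] = MEMORY_MAX - m._mem_max_before

        mod.register_forward_pre_hook(pre_hook)
        mod.register_forward_hook(post_hook)

with FakeTensorMode(), MemoryTrackingMode():
    model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
    attach_memory_hooks(model)
    model(torch.rand(32, 128))
    print(MODULE_PEAK)

For the fragmentation question, the toy below is a deliberately tiny mock of a caching allocator (ToyCachingAllocator is a made-up name; the real CUDA caching allocator has small/large pools, block splitting, and much more that this ignores). It only illustrates why reporting "max reserved" alongside "max allocated" requires remembering which freed blocks stay cached.

# Toy model only: rounds requests, caches freed blocks instead of releasing
# them, and tracks both allocated and reserved peaks. Not the real allocator.
class ToyCachingAllocator:
    def __init__(self):
        self.free_blocks = []   # sizes of freed-but-cached blocks
        self.allocated = 0
        self.reserved = 0
        self.max_allocated = 0
        self.max_reserved = 0

    def malloc(self, nbytes):
        size = math.ceil(nbytes / PYTORCH_MIN_ALLOCATE) * PYTORCH_MIN_ALLOCATE
        # Reuse the smallest cached block that fits, otherwise reserve more.
        fits = [b for b in self.free_blocks if b >= size]
        if fits:
            block = min(fits)
            self.free_blocks.remove(block)
        else:
            block = size
            self.reserved += block
        self.allocated += block
        self.max_allocated = max(self.max_allocated, self.allocated)
        self.max_reserved = max(self.max_reserved, self.reserved)
        return block

    def free(self, block):
        # Keep the block cached: reserved stays the same, allocated drops.
        self.allocated -= block
        self.free_blocks.append(block)

Hooking something like this into the storage tracking above would let a FakeTensorMode run report a (mock) max-reserved number next to max-allocated, under the stated simplifying assumptions.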