How to measure the memory usage of your model without running it?

The short answer is this 30-line TorchDispatchMode that tracks all Tensor memory use: subclass_zoo/max_mem_tracker.py (main branch of albanD/subclass_zoo on GitHub)

```python
import math
import weakref

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary

# Track all the memory being used by Tensors.
# Only the max is tracked, but other statistics can be added.
MEMORY_USE = WeakIdKeyDictionary()
MEMORY_MAX = 0
# Minimum allocation size: every allocation is rounded up to a multiple of this
PYTORCH_MIN_ALLOCATE = 2**9

def update_stats():
    global MEMORY_MAX
    curr_use = 0
    # Keys are the live UntypedStorages; nbytes() is each storage's size in bytes
    for k in MEMORY_USE.keys():
        curr_use += math.ceil(k.nbytes() / PYTORCH_MIN_ALLOCATE) * PYTORCH_MIN_ALLOCATE

    if MEMORY_MAX < curr_use:
        MEMORY_MAX = curr_use

# Should be called on every Tensor created
def track(t: torch.Tensor):
    # Re-run update_stats() when the storage is garbage collected
    def cb(_):
        update_stats()
    st = t.untyped_storage()
    wt = weakref.ref(st, cb)
    MEMORY_USE[st] = wt
    update_stats()

# Use this Mode to call track on every Tensor being created by functions
class MemoryTrackingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        res = func(*args, **(kwargs or {}))

        tree_map_only(torch.Tensor, track, res)
        return res
```

This has the added benefit of working under FakeTensorMode, which means you can measure memory usage without actually allocating any memory! Even better, in most cases you can measure CUDA memory use with a CPU-only build.

```python
from torch._subclasses.fake_tensor import FakeTensorMode

def f(a):
    b = a * 10
    d = b + 3
    return d

with FakeTensorMode(), MemoryTrackingMode():
    # This works even on a cpu-only build!
    a = torch.rand(100, device="cuda")
    f(a)
    print(f"Just f: {MEMORY_MAX}")
```

What is next?

Tell us what you want!

  • The PR “Memory Tracker for tracking Module wise memory” by sanketpurandare (pytorch/pytorch #124688) builds on this and adds an nn.Module tracker to provide Module-aware memory usage.
    • We should refactor the Module tracking so it can be shared with the FlopCounter code
    • If anyone wants to help get that PR merged faster, feel free to reach out!
  • Are you interested in memory fragmentation? Should we have a way to mimic the behavior of our caching allocator so that we can report not only max memory allocated but also max memory reserved?
    • I think we should be able to do that either by refactoring the caching allocator so it can be called from Python or by mimicking its behavior in Python.
  • Do we want some kind of integration with the memory visualizer from the PyTorch blog post “Understanding GPU Memory 1: Visualizing All Allocations over Time”, so that these memory profiles can be generated from within FakeTensorMode?
    • I think we should be able to do that by either re-using the capture logic in the caching allocator or by generating the same trace directly from our mockup.
  • Other features you would like to see there?
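The idea of mimicking the caching allocator in Python can be illustrated with a toy model that tracks max allocated and max reserved side by side. This is a hypothetical, heavily simplified sketch, not the real CUDA caching allocator: it only rounds requests up to 512 bytes and reuses cached blocks of exactly the same size, whereas the real allocator also splits and merges blocks and reserves memory in larger segments. ToyCachingAllocator and MIN_BLOCK are illustrative names.

```python
import math
from collections import defaultdict

MIN_BLOCK = 512  # round every request up to a multiple of 512 bytes


class ToyCachingAllocator:
    """Hypothetical model: freed blocks stay cached and are reused only for the same size."""

    def __init__(self):
        self.free_blocks = defaultdict(int)  # block size -> number of cached free blocks
        self.allocated = 0      # bytes handed out to live tensors
        self.reserved = 0       # bytes obtained from the device (in use + cached)
        self.max_allocated = 0
        self.max_reserved = 0

    @staticmethod
    def round_size(nbytes):
        return math.ceil(nbytes / MIN_BLOCK) * MIN_BLOCK

    def malloc(self, nbytes):
        size = self.round_size(nbytes)
        if self.free_blocks[size] > 0:
            self.free_blocks[size] -= 1   # cache hit: reuse a block, no new reservation
        else:
            self.reserved += size         # cache miss: reserve more device memory
        self.allocated += size
        self.max_allocated = max(self.max_allocated, self.allocated)
        self.max_reserved = max(self.max_reserved, self.reserved)
        return size

    def free(self, size):
        self.allocated -= size
        self.free_blocks[size] += 1       # keep the block cached instead of releasing it


alloc = ToyCachingAllocator()
small = alloc.malloc(400)    # rounded to 512
big = alloc.malloc(1000)     # rounded to 1024
alloc.free(small)            # the 512-byte block stays reserved in the cache
alloc.malloc(1000)           # cannot reuse the cached 512-byte block
print(alloc.max_allocated, alloc.max_reserved)
```

In this scenario the cached 512-byte block is too small for the second 1000-byte request, so max reserved (2560 bytes) ends up above max allocated (2048 bytes). That gap is the fragmentation effect that a max-allocated-only tracker cannot see.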

Hi @albanD, it’s really cool to be able to inspect memory usage without using any memory, and I’d like to complete the PR.

Building on the PR “Memory Tracker for tracking Module wise memory” by sanketpurandare (pytorch/pytorch #124688) and on FlopCounterMode, I have refactored the code to track the memory usage of each module. Here are my initial ideas:

  • Inheriting from FlopCounterMode to create MemoryTrackingMode
  • Using a dict, memroy_tracker_collections, to map tensors to module names, keyed by the tensor’s _RefType:
    ```python
    # The relationship between tensor and module name: {_RefType name: {tensor's storage: {module name set}}}
    self.memroy_tracker_collections: Dict[str, WeakIdKeyDictionary[torch.storage.UntypedStorage, set]] = (
        defaultdict(WeakIdKeyDictionary)
    )
    ```

You can find the code snippet here: MemoryTrackingMode.py on GitHub. Do you have any suggestions about this implementation?

While testing FlopCounterMode with a loop on a more complex model, I encountered a crash:
AssertionError: Global is not DummyModel.module.
Here is the code to reproduce the issue: mt_with_loop.py on GitHub.

Hey!

Thanks for looking into this.

I wasn’t expecting the memory tracker to inherit from the flop counter. I was thinking more of a third class, the Module tracker, that both features would then inherit from.

Also, I haven’t looked at how the module tracker is implemented, but I suspect we can simplify it significantly by using the global all-module hooks, so that the user doesn’t need to pass in their module and we don’t need to manage so many hooks.
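To make the all-module-hooks idea concrete, here is a minimal forward-only sketch built on torch.nn.modules.module.register_module_forward_pre_hook and register_module_forward_hook, which fire for every nn.Module call while registered. SimpleModuleTracker, parents and history are hypothetical names. Because no root module is passed in, the sketch records class names rather than fully qualified module names, and it does not handle backward.

```python
import torch
from torch import nn
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
)


class SimpleModuleTracker:
    """Hypothetical forward-only module tracker using global module hooks."""

    def __init__(self):
        self.parents = []   # class names of the modules whose forward is currently running
        self.history = []   # snapshot of `parents` each time a module's forward starts
        self._handles = []

    def _pre_hook(self, module, args):
        self.parents.append(type(module).__name__)
        self.history.append(tuple(self.parents))

    def _post_hook(self, module, args, output):
        self.parents.pop()

    def __enter__(self):
        # Global hooks: no need for the user to hand over their model
        self._handles = [
            register_module_forward_pre_hook(self._pre_hook),
            register_module_forward_hook(self._post_hook),
        ]
        return self

    def __exit__(self, *exc_info):
        for handle in self._handles:
            handle.remove()
        self._handles = []
        return False


model = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
with SimpleModuleTracker() as tracker:
    model(torch.randn(1, 2))
print(tracker.history)
```

A memory tracker or a flop counter could subclass such a tracker and read `parents` whenever it records an allocation or an op. Since the stack is pushed and popped on every call, a module that runs several times in a loop simply appears several times in `history`.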