Strange increase of non-torch GPU memory in unexpected functions

Hi team, I'm investigating how to measure the CUDA memory used by non-PyTorch components, such as NCCL. I have come up with a function that seems to do the trick:

def measure_current_non_torch(device=None):
    """Return the bytes of GPU memory in use that PyTorch's allocator does not hold.

    The value is the device's used memory as reported by the CUDA driver
    (``total - free`` from :func:`torch.cuda.mem_get_info`) minus the memory
    reserved by PyTorch's caching allocator
    (:func:`torch.cuda.memory_reserved`) on the same device.

    Caveats when interpreting the result:

    * ``mem_get_info`` reports global, device-wide numbers, so memory used by
      *other processes* on the same GPU is counted here as well.
    * Everything allocated outside the caching allocator is included, e.g.
      the CUDA context and NCCL buffers. NOTE(review): library handles,
      workspaces and lazily loaded kernel images (cuDNN/cuBLAS) presumably
      also land here and may appear the first time an operator runs --
      confirm for the specific setup.

    Args:
        device: Device to query (``torch.device``, index, or ``None`` for the
            current device). Both readings are taken from this same device.

    Returns:
        Memory in bytes that is in use on the device but not reserved by
        PyTorch's caching allocator.
    """
    free, total = torch.cuda.mem_get_info(device)
    used = total - free
    # Query the allocator for the SAME device so the subtraction is meaningful.
    reserved_by_torch = torch.cuda.memory_reserved(device)
    return used - reserved_by_torch

If I only allocate tensors in PyTorch, this gives a consistent value: the only non-torch memory is the CUDA context.

However, when I run some computation, unexpected non-torch memory appears.

I tried tracing every line of Python code to monitor changes in non-torch memory:

import contextlib
import dataclasses
import sys
import traceback
from typing import Callable, Generator, Generic, TypeVar

_T = TypeVar("_T")


@dataclasses.dataclass
class MonitoredValues(Generic[_T]):
    """Record of every distinct value observed by :func:`monitor`.

    The lists are parallel: index ``i`` of each one describes the ``i``-th
    recorded value.
    """

    # Successive distinct measurements, in the order they were observed.
    values: list[_T] = dataclasses.field(default_factory=list)
    # Formatted Python stack (outermost frame first) of the line that was
    # about to run when the change was observed.
    trace_stacks: list[str] = dataclasses.field(default_factory=list)
    # "file:lineno in func" of the previously traced line, i.e. the line that
    # most likely caused the change ("" for the very first measurement).
    previous_lines: list[str] = dataclasses.field(default_factory=list)


@contextlib.contextmanager
def monitor(
    measure_func: Callable[[], _T],
) -> Generator[MonitoredValues[_T], None, None]:
    """Trace Python execution and record every change of a measured value.

    While the context is active, ``measure_func`` is called on every
    ``'line'`` trace event, i.e. right before each traced Python line runs.
    Whenever the result differs (``!=``) from the last recorded value, the
    new value, the stack of the line about to run, and the location of the
    previously traced line are appended to the yielded ``MonitoredValues``.

    Attribution: a ``'line'`` event fires *before* its line executes, so a
    change is detected one line late. Whatever caused it ran after the
    previous traced line started -- usually inside that line, e.g. in a
    C/C++ call, which produces no line events of its own. ``previous_lines``
    points at the likely culprit; ``trace_stacks`` shows where it was seen.

    Limitations:

    * Only frames entered after the context starts are traced; the frame
      containing the ``with`` statement itself produces no line events.
    * :func:`sys.settrace` is per-thread; other threads are not monitored.
    * CPython suspends tracing while a trace function runs, so
      ``measure_func`` is never traced itself and tracing need not be
      toggled around it.
    * Any trace function installed before entering is restored on exit.

    Usage::

        def measure_func():
            ...  # measure the current value
            return current_value

        with monitor(measure_func) as monitored_values:
            ...  # do something

        monitored_values.values          # all distinct values, in order
        monitored_values.trace_stacks    # where each change was observed
        monitored_values.previous_lines  # line that most likely caused it

    Args:
        measure_func: Zero-argument callable returning the value to watch.

    Yields:
        The ``MonitoredValues`` being filled in; it keeps its contents after
        the context exits.
    """
    monitored_values = MonitoredValues[_T]()
    # (filename, lineno, function name) of the last traced line, or None.
    previous_location = None

    def _trace_calls(frame, event, arg=None):
        nonlocal previous_location
        # 'line' events are only emitted for Python code, so work done in
        # C/C++ shows up at the next traced Python line.
        if event == "line":
            try:
                current_value = measure_func()
            except NameError:
                # Modules may already be torn down during interpreter
                # shutdown; skip this sample rather than crash the program.
                pass
            else:
                values = monitored_values.values
                if not values or current_value != values[-1]:
                    values.append(current_value)
                    # Start the stack at the traced frame so the tracer's own
                    # frame does not pollute the report.
                    monitored_values.trace_stacks.append(
                        "".join(traceback.format_stack(frame)))
                    if previous_location is None:
                        culprit = ""
                    else:
                        path, lineno, func_name = previous_location
                        culprit = f"{path}:{lineno} in {func_name}"
                    monitored_values.previous_lines.append(culprit)
            code = frame.f_code
            previous_location = (code.co_filename, frame.f_lineno,
                                 code.co_name)
        # Returning the tracer keeps line events enabled for this frame.
        return _trace_calls

    previous_trace = sys.gettrace()
    sys.settrace(_trace_calls)
    try:
        yield monitored_values
    finally:
        # Restore whatever tracer (debugger, coverage, ...) was active before.
        sys.settrace(previous_trace)

import torch
import torchvision

def f():
    """Run one ResNet-50 forward pass on a random batch on the current GPU."""
    # Both the model parameters and the input batch are moved with ``.cuda()``.
    model = torchvision.models.resnet50().cuda()
    batch = torch.randn(64, 3, 224, 224).cuda()
    # Bind the result so it stays referenced until the function returns.
    _logits = model(batch)

def measure_current_non_torch(device=None):
    """Return the bytes of GPU memory in use that PyTorch's allocator does not hold.

    The value is the device's used memory as reported by the CUDA driver
    (``total - free`` from :func:`torch.cuda.mem_get_info`) minus the memory
    reserved by PyTorch's caching allocator
    (:func:`torch.cuda.memory_reserved`) on the same device.

    Caveats when interpreting the result:

    * ``mem_get_info`` reports global, device-wide numbers, so memory used by
      *other processes* on the same GPU is counted here as well.
    * Everything allocated outside the caching allocator is included, e.g.
      the CUDA context and NCCL buffers. NOTE(review): library handles,
      workspaces and lazily loaded kernel images (cuDNN/cuBLAS) presumably
      also land here and may appear the first time an operator runs --
      confirm for the specific setup.

    Args:
        device: Device to query (``torch.device``, index, or ``None`` for the
            current device). Both readings are taken from this same device.

    Returns:
        Memory in bytes that is in use on the device but not reserved by
        PyTorch's caching allocator.
    """
    free, total = torch.cuda.mem_get_info(device)
    used = total - free
    # Query the allocator for the SAME device so the subtraction is meaningful.
    reserved_by_torch = torch.cuda.memory_reserved(device)
    return used - reserved_by_torch

with monitor(measure_current_non_torch) as monitor_values:
    f()

# A 'line' trace event fires before its line runs, so each stack below shows
# where a change was first observed, not necessarily the line that caused it.
changes = zip(monitor_values.values, monitor_values.trace_stacks)
for value, stack in changes:
    value_mib = value / 1024 / 1024
    print(f"non_torch memory changed to {value_mib} MiB in\n")
    print(stack + "\n")

The output is surprising:

non_torch memory changed to 529.0625 MiB in

  File "/data/youkaichao/vllm/testf.py", line 86, in <module>
    f()
  File "/data/youkaichao/vllm/testf.py", line 74, in f
    net = torchvision.models.resnet50().cuda()
  File "/data/youkaichao/vllm/testf.py", line 56, in _trace_calls
    traceback.format_stack()))


non_torch memory changed to 541.0625 MiB in

  File "/data/youkaichao/vllm/testf.py", line 86, in <module>
    f()
  File "/data/youkaichao/vllm/testf.py", line 76, in f
    outputs = net(inputs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 285, in forward
    return self._forward_impl(x)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 269, in _forward_impl
    x = self.bn1(x)
  File "/data/youkaichao/vllm/testf.py", line 56, in _trace_calls
    traceback.format_stack()))


non_torch memory changed to 631.0625 MiB in

  File "/data/youkaichao/vllm/testf.py", line 86, in <module>
    f()
  File "/data/youkaichao/vllm/testf.py", line 76, in f
    outputs = net(inputs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 285, in forward
    return self._forward_impl(x)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 269, in _forward_impl
    x = self.bn1(x)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 174, in forward
    if self.momentum is None:  # use cumulative moving average
  File "/data/youkaichao/vllm/testf.py", line 56, in _trace_calls
    traceback.format_stack()))


non_torch memory changed to 633.0625 MiB in

  File "/data/youkaichao/vllm/testf.py", line 86, in <module>
    f()
  File "/data/youkaichao/vllm/testf.py", line 76, in f
    outputs = net(inputs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 285, in forward
    return self._forward_impl(x)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 273, in _forward_impl
    x = self.layer1(x)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 155, in forward
    out = self.bn3(out)
  File "/data/youkaichao/vllm/testf.py", line 56, in _trace_calls
    traceback.format_stack()))


non_torch memory changed to 697.0625 MiB in

  File "/data/youkaichao/vllm/testf.py", line 86, in <module>
    f()
  File "/data/youkaichao/vllm/testf.py", line 76, in f
    outputs = net(inputs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 285, in forward
    return self._forward_impl(x)
  File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 282, in _forward_impl
    return x
  File "/data/youkaichao/vllm/testf.py", line 56, in _trace_calls
    traceback.format_stack()))

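In particular, non-torch memory increases right before `x = self.bn1(x)` executes. Note that a `'line'` trace event fires before its line runs, so the allocation actually happened in code executed since the previously traced line.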

I also have a very long trace from running an LLM: tracing_non_torch_memory_for_rank_0_in_tp_1.txt (Google Drive).

It is even stranger with CUDA graph capture: running a linear layer causes a sudden increase and then a decrease in non-torch memory.

So, my question is:

Is my `measure_current_non_torch` function accurate? Is there anything I missed?

Thanks!