Hi team, I'm investigating how to measure the CUDA memory used by non-PyTorch components such as NCCL. I have come up with a function that seems to do the trick:
def measure_current_non_torch(device=None):
    """Estimate the CUDA memory on a device that PyTorch's allocator does not own.

    The result is ``(total - free)`` as reported by
    ``torch.cuda.mem_get_info`` minus ``torch.cuda.memory_reserved``.

    Caveats:
    - ``mem_get_info`` reports device-wide numbers (``cudaMemGetInfo``), so
      memory used by *other processes* on the same GPU is counted here too.
    - ``memory_reserved`` only covers this process's caching allocator, so
      anything allocated outside it (e.g. the CUDA context, NCCL buffers)
      shows up in the result.

    Args:
        device: device to query. ``None`` (the default) means the current
            device, which is what both underlying calls used before.

    Returns:
        The difference in bytes.
    """
    free, total = torch.cuda.mem_get_info(device)
    current_used = total - free
    # Query the same device for both numbers so they are comparable.
    current_torch = torch.cuda.memory_reserved(device)
    return current_used - current_torch
If I only allocate tensors in PyTorch, this gives a consistent value: the only non-torch memory is the CUDA context.
However, when I run some computation, unexpected non-torch memory appears.
To find out where it comes from, I trace every line of Python code and record each change in non-torch memory:
import contextlib
import dataclasses
import sys
import traceback
from typing import Callable, Generator, Generic, TypeVar
_T = TypeVar("_T")
@dataclasses.dataclass
class MonitoredValues(Generic[_T]):
values: list[_T] = dataclasses.field(default_factory=list)
trace_stacks: list[str] = dataclasses.field(default_factory=list)
@contextlib.contextmanager
def monitor(
measure_func: Callable[[],
_T]) -> Generator[MonitoredValues[_T], None, None]:
"""
Trace the function calls to continuously monitor the change of
a value.
Usage:
```python
def measure_func():
... # measure the current value
return current_value
with monitor(measure_func) as monitored_values:
# do something
monitored_values.values # all changes of the values
monitored_values.trace_stacks # trace stacks of every change
```
"""
monitored_values = MonitoredValues[_T]()
def _trace_calls(frame, event, arg=None):
nonlocal monitored_values
if event in ['line']:
# triggered by every line of Python code.
# only Python functions will trigger it,
# c/cpp functions will not trigger it.
try:
# Temporarily disable the trace function
sys.settrace(None)
# do a measurement
current_value = measure_func()
if len(monitored_values.values
) == 0 or current_value != monitored_values.values[-1]:
monitored_values.values.append(current_value)
monitored_values.trace_stacks.append("".join(
traceback.format_stack()))
# Re-enable the trace function
sys.settrace(_trace_calls)
except NameError:
# modules are deleted during shutdown
pass
return _trace_calls
try:
sys.settrace(_trace_calls)
yield monitored_values
finally:
sys.settrace(None)
import torch
import torchvision
def f():
    """Run a single ResNet-50 forward pass on a random batch on the GPU.

    Nothing is returned; the function only generates CUDA work for the
    memory monitor to observe.
    """
    model = torchvision.models.resnet50().cuda()
    # 64 random 3x224x224 inputs, created on the host and copied to the GPU.
    batch = torch.randn((64, 3, 224, 224)).cuda()
    # Bound to a local so the result stays alive until the function returns.
    result = model(batch)
def measure_current_non_torch(device=None):
    """Estimate the CUDA memory on a device that PyTorch's allocator does not own.

    The result is ``(total - free)`` as reported by
    ``torch.cuda.mem_get_info`` minus ``torch.cuda.memory_reserved``.

    Caveats:
    - ``mem_get_info`` reports device-wide numbers (``cudaMemGetInfo``), so
      memory used by *other processes* on the same GPU is counted here too.
    - ``memory_reserved`` only covers this process's caching allocator, so
      anything allocated outside it (e.g. the CUDA context, NCCL buffers)
      shows up in the result.

    Args:
        device: device to query. ``None`` (the default) means the current
            device, which is what both underlying calls used before.

    Returns:
        The difference in bytes.
    """
    free, total = torch.cuda.mem_get_info(device)
    current_used = total - free
    # Query the same device for both numbers so they are comparable.
    current_torch = torch.cuda.memory_reserved(device)
    return current_used - current_torch
with monitor(measure_current_non_torch) as monitor_values:
    f()

# Print each recorded change of non-torch memory (converted to MiB),
# followed by the stack at which that change was observed.
for value, stack in zip(monitor_values.values, monitor_values.trace_stacks):
    print(f"non_torch memory changed to {value / 1024 / 1024} MiB in\n")
    print(f"{stack}\n")
The output is surprising:
non_torch memory changed to 529.0625 MiB in
File "/data/youkaichao/vllm/testf.py", line 86, in <module>
f()
File "/data/youkaichao/vllm/testf.py", line 74, in f
net = torchvision.models.resnet50().cuda()
File "/data/youkaichao/vllm/testf.py", line 56, in _trace_calls
traceback.format_stack()))
non_torch memory changed to 541.0625 MiB in
File "/data/youkaichao/vllm/testf.py", line 86, in <module>
f()
File "/data/youkaichao/vllm/testf.py", line 76, in f
outputs = net(inputs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 285, in forward
return self._forward_impl(x)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 269, in _forward_impl
x = self.bn1(x)
File "/data/youkaichao/vllm/testf.py", line 56, in _trace_calls
traceback.format_stack()))
non_torch memory changed to 631.0625 MiB in
File "/data/youkaichao/vllm/testf.py", line 86, in <module>
f()
File "/data/youkaichao/vllm/testf.py", line 76, in f
outputs = net(inputs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 285, in forward
return self._forward_impl(x)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 269, in _forward_impl
x = self.bn1(x)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 174, in forward
if self.momentum is None: # use cumulative moving average
File "/data/youkaichao/vllm/testf.py", line 56, in _trace_calls
traceback.format_stack()))
non_torch memory changed to 633.0625 MiB in
File "/data/youkaichao/vllm/testf.py", line 86, in <module>
f()
File "/data/youkaichao/vllm/testf.py", line 76, in f
outputs = net(inputs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 285, in forward
return self._forward_impl(x)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 273, in _forward_impl
x = self.layer1(x)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
input = module(input)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 155, in forward
out = self.bn3(out)
File "/data/youkaichao/vllm/testf.py", line 56, in _trace_calls
traceback.format_stack()))
non_torch memory changed to 697.0625 MiB in
File "/data/youkaichao/vllm/testf.py", line 86, in <module>
f()
File "/data/youkaichao/vllm/testf.py", line 76, in f
outputs = net(inputs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 285, in forward
return self._forward_impl(x)
File "/data/youkaichao/uv_envs/py310/lib/python3.10/site-packages/torchvision/models/resnet.py", line 282, in _forward_impl
return x
File "/data/youkaichao/vllm/testf.py", line 56, in _trace_calls
traceback.format_stack()))
In particular, `x = self.bn1(x)` seems to cause an increase in non-torch memory.
I also have a very long trace from running an LLM: tracing_non_torch_memory_for_rank_0_in_tp_1.txt (Google Drive).
It is even stranger: with CUDA graph capture, running a linear layer causes a sudden increase and then a decrease in non-torch memory.
So, my question is: is my `measure_current_non_torch` function accurate? Is there anything I missed?
Thanks!