Recently, I’ve been playing around with PyTorch’s memory allocator, and I’d like to record here what I found.
First, as we all know, PyTorch uses a caching CUDA allocator by default. To monitor every allocation and free at the CUDA runtime level, we can use an LD_AUDIT (rtld-audit) library:
// save as audit_cuda.c
// compile with gcc -shared -fPIC -ldl -o libaudit.so audit_cuda.c -I/usr/local/cuda/include
#define _GNU_SOURCE
#include <link.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// cuda_runtime.h provides cudaError_t and the cudaMalloc/cudaFree signatures
#include <cuda_runtime.h>
typedef cudaError_t (*cudaMalloc_t)(void **, size_t);
typedef cudaError_t (*cudaFree_t)(void *);

// Global pointers to the REAL functions
static cudaMalloc_t real_cudaMalloc = NULL;
static cudaFree_t real_cudaFree = NULL;

// Interceptor for cudaMalloc
static cudaError_t my_cudaMalloc(void **devPtr, size_t size) {
    printf("[Intercepted] cudaMalloc(size=%zu)\n", size);
    // Call the real cudaMalloc
    return real_cudaMalloc(devPtr, size);
}

// Interceptor for cudaFree
static cudaError_t my_cudaFree(void *devPtr) {
    printf("[Intercepted] cudaFree(ptr=%p)\n", devPtr);
    // Call the real cudaFree
    return real_cudaFree(devPtr);
}
// Report the rtld-audit interface version we implement
unsigned int la_version(unsigned int version) {
    return LAV_CURRENT;
}

// Do not rewrite library search paths, just pass them through
char *la_objsearch(const char *name, uintptr_t *cookie, unsigned int flag) {
    return (char *)name;
}

// Ask the dynamic linker to notify us about bindings to and from every object
unsigned int la_objopen(struct link_map *map, Lmid_t lmid, uintptr_t *cookie) {
    *cookie = (uintptr_t)map;
    return LA_FLG_BINDTO | LA_FLG_BINDFROM;
}

// Called for every PLT symbol binding; this is where we reroute cudaMalloc/cudaFree
uintptr_t la_symbind64(Elf64_Sym *sym, unsigned int ndx, uintptr_t *refcook,
                       uintptr_t *defcook, unsigned int *flags, const char *symname) {
    struct link_map *map = (struct link_map *)*defcook;
    // If we detect cudaMalloc, store the real function pointer and return our interceptor
    if (strcmp(symname, "cudaMalloc") == 0 && map) {
        real_cudaMalloc = (cudaMalloc_t)((uintptr_t)sym->st_value);
        // Return the address of our "my_cudaMalloc" so the calls get rerouted
        return (uintptr_t)my_cudaMalloc;
    }
    // If we detect cudaFree, store the real function pointer and return our interceptor
    if (strcmp(symname, "cudaFree") == 0 && map) {
        real_cudaFree = (cudaFree_t)((uintptr_t)sym->st_value);
        return (uintptr_t)my_cudaFree;
    }
    return sym->st_value;
}

unsigned int la_objclose(uintptr_t *cookie) {
    return 0;
}

// Nothing to clean up at exit
void __attribute__((destructor)) finalize() {
}
The Python-level code:
# save as test.py
import torch

for factor in (1024, 1024 ** 2, 1024 ** 3):
    print(f"Allocate {60 * factor} bytes of memory on the GPU from Python")
    data = torch.empty((60, factor), dtype=torch.uint8, device="cuda")
    print(f"Free {60 * factor} bytes of memory on the GPU from Python")
    del data
    print("Python side: memory is released")
    print(f"Allocate {70 * factor} bytes of memory on the GPU from Python")
    data = torch.empty((70, factor), dtype=torch.uint8, device="cuda")
    print(f"Free {70 * factor} bytes of memory on the GPU from Python")
    del data
    print("Python side: memory is released")
Run with LD_AUDIT=$PWD/libaudit.so python test.py
The output:
Allocate 61440 bytes of memory on the GPU from Python
[Intercepted] cudaMalloc(size=2097152)
Free 61440 bytes of memory on the GPU from Python
Python side: memory is released
Allocate 71680 bytes of memory on the GPU from Python
Free 71680 bytes of memory on the GPU from Python
Python side: memory is released
Allocate 62914560 bytes of memory on the GPU from Python
[Intercepted] cudaMalloc(size=62914560)
Free 62914560 bytes of memory on the GPU from Python
Python side: memory is released
Allocate 73400320 bytes of memory on the GPU from Python
[Intercepted] cudaMalloc(size=73400320)
Free 73400320 bytes of memory on the GPU from Python
Python side: memory is released
Allocate 64424509440 bytes of memory on the GPU from Python
[Intercepted] cudaMalloc(size=64424509440)
Free 64424509440 bytes of memory on the GPU from Python
Python side: memory is released
Allocate 75161927680 bytes of memory on the GPU from Python
[Intercepted] cudaMalloc(size=75161927680)
[Intercepted] cudaFree(ptr=0x7f5af2000000)
[Intercepted] cudaFree(ptr=0x7f5aec000000)
[Intercepted] cudaFree(ptr=0x7f4be0000000)
[Intercepted] cudaFree(ptr=0x7f5b0fe00000)
[Intercepted] cudaMalloc(size=75161927680)
Free 75161927680 bytes of memory on the GPU from Python
Python side: memory is released
How to interpret the results?
- On the Python side we ask for 60 KiB; PyTorch asks CUDA for a 2 MiB block (the caching allocator rounds small requests up to a 2 MiB segment; this can also be confirmed from Python, see the sketch after this list).
- We release the 60 KiB back to PyTorch. Nothing happens on the CUDA side; PyTorch keeps holding the 2 MiB block.
- On the Python side we ask for 70 KiB. Nothing happens on the CUDA side; the request is served from the cached 2 MiB block.
- We release the 70 KiB back to PyTorch. Nothing happens on the CUDA side; PyTorch keeps holding the 2 MiB block.
- On the Python side we ask for 60 MiB; PyTorch asks CUDA for 60 MiB directly.
- We release the 60 MiB back to PyTorch. Nothing happens on the CUDA side; PyTorch now holds 62 MiB in total.
- On the Python side we ask for 70 MiB; PyTorch asks CUDA for 70 MiB directly.
- We release the 70 MiB back to PyTorch. Nothing happens on the CUDA side; PyTorch now holds 132 MiB in total.
- On the Python side we ask for 60 GiB; PyTorch asks CUDA for 60 GiB directly.
- We release the 60 GiB back to PyTorch. Nothing happens on the CUDA side; PyTorch now holds 60 GiB + 132 MiB in total.
- On the Python side we ask for 70 GiB; PyTorch asks CUDA for 70 GiB. This is the interesting part: because PyTorch is already holding a lot of memory, CUDA cannot satisfy another 70 GiB and the first cudaMalloc fails. PyTorch then returns the reserved but unused memory (the 60 GiB + 132 MiB, visible as the four cudaFree calls in the log) and retries the allocation, which succeeds.
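We can double-check the rounding and the reservation without LD_AUDIT by asking PyTorch’s own counters. Below is a minimal sketch using the documented torch.cuda.memory_allocated() and torch.cuda.memory_reserved() helpers; the exact rounded sizes may differ across PyTorch versions and allocator settings.
import torch

data = torch.empty((60, 1024), dtype=torch.uint8, device="cuda")  # ask for 60 KiB
print("allocated:", torch.cuda.memory_allocated())  # bytes handed out to live tensors
print("reserved: ", torch.cuda.memory_reserved())   # bytes PyTorch took from CUDA (2 MiB here)

del data
print("allocated:", torch.cuda.memory_allocated())  # drops back to 0
print("reserved: ", torch.cuda.memory_reserved())   # stays at the cached block size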
The lessons I learned about the allocation and free behavior:
- When an allocation fails, PyTorch releases its cached but unused memory and retries the allocation.
- When memory is freed from Python, PyTorch does not return it to CUDA; it keeps the memory reserved for later reuse (it can be released explicitly, as shown in the sketch below).
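The reserved memory can also be flushed manually with torch.cuda.empty_cache(), which releases all cached, currently unused blocks back to CUDA. A minimal sketch; if run under the LD_AUDIT library above, the call should show up as [Intercepted] cudaFree lines:
import torch

data = torch.empty((60, 1024 ** 2), dtype=torch.uint8, device="cuda")
del data
print("reserved before empty_cache:", torch.cuda.memory_reserved())

# Return all cached, unused blocks to CUDA. This is also roughly what PyTorch
# does internally when a cudaMalloc fails and it retries.
torch.cuda.empty_cache()
print("reserved after empty_cache: ", torch.cuda.memory_reserved())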
While I expected the same caching behavior to apply to a pluggable allocator, that is not the case. The CUDAPluggableAllocator simply routes every allocation and free request to the underlying functions, without any caching.
// save as alloc.cc
// compile with g++ alloc.cc -o alloc.so -I/usr/local/cuda/include -shared -fPIC
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>
extern "C" {
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
void *ptr;
cudaMalloc(&ptr, size);
std::cout<<"alloc "<<ptr<< " " <<size<<std::endl;
return ptr;
}
void my_free(void* ptr, ssize_t size, int device, cudaStream_t stream) {
std::cout<<"free "<<ptr<< " "<<size<<std::endl;
cudaFree(ptr);
}
}
The Python-level code:
# test.py
import torch

# Load the allocator
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
    './alloc.so', 'my_malloc', 'my_free')
# Swap the current allocator
torch.cuda.memory.change_current_allocator(new_alloc)

for factor in (1024, 1024 ** 2, 1024 ** 3):
    print(f"Allocate {60 * factor} bytes of memory on the GPU from Python")
    data = torch.empty((60, factor), dtype=torch.uint8, device="cuda")
    print(f"Free {60 * factor} bytes of memory on the GPU from Python")
    del data
    print("Python side: memory is released")
    print(f"Allocate {70 * factor} bytes of memory on the GPU from Python")
    data = torch.empty((70, factor), dtype=torch.uint8, device="cuda")
    print(f"Free {70 * factor} bytes of memory on the GPU from Python")
    del data
    print("Python side: memory is released")
Run with python test.py, and we get:
Allocate 61440 bytes of memory on the GPU from Python
alloc 0x7f5c2fe00000 61440
Free 61440 bytes of memory on the GPU from Python
free 0x7f5c2fe00000 61440
Python side: memory is released
Allocate 71680 bytes of memory on the GPU from Python
alloc 0x7f5c2fe00000 71680
Free 71680 bytes of memory on the GPU from Python
free 0x7f5c2fe00000 71680
Python side: memory is released
Allocate 62914560 bytes of memory on the GPU from Python
alloc 0x7f5c12000000 62914560
Free 62914560 bytes of memory on the GPU from Python
free 0x7f5c12000000 62914560
Python side: memory is released
Allocate 73400320 bytes of memory on the GPU from Python
alloc 0x7f5c10000000 73400320
Free 73400320 bytes of memory on the GPU from Python
free 0x7f5c10000000 73400320
Python side: memory is released
Allocate 64424509440 bytes of memory on the GPU from Python
alloc 0x7f4d00000000 64424509440
Free 64424509440 bytes of memory on the GPU from Python
free 0x7f4d00000000 64424509440
Python side: memory is released
Allocate 75161927680 bytes of memory on the GPU from Python
alloc 0x7f4a80000000 75161927680
Free 75161927680 bytes of memory on the GPU from Python
free 0x7f4a80000000 75161927680
Python side: memory is released
It becomes obvious that CUDAPluggableAllocator does not do any caching: all memory requests are routed directly to the user-provided allocation and free functions. It would be better if we could have a new argument like enable_caching=True, so that CUDAPluggableAllocator could also reuse PyTorch’s caching behavior.
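To make the proposal a bit more concrete, here is a rough, purely illustrative Python sketch of the policy such an enable_caching option could implement on top of the user-provided malloc/free pair: keep freed blocks in per-size free lists, serve repeated requests from the cache, and only call the real free when the cache is explicitly flushed. The CachingWrapper name and structure are my own invention, not an existing PyTorch API, and the real caching allocator is far more sophisticated (block splitting, per-stream pools, an OOM-retry path):
from collections import defaultdict

class CachingWrapper:
    """Illustrative size-bucketed cache around a raw malloc/free pair (hypothetical)."""

    def __init__(self, raw_malloc, raw_free):
        self.raw_malloc = raw_malloc
        self.raw_free = raw_free
        self.free_blocks = defaultdict(list)  # size -> cached pointers

    def malloc(self, size):
        if self.free_blocks[size]:
            return self.free_blocks[size].pop()  # reuse a cached block, no raw call
        return self.raw_malloc(size)

    def free(self, ptr, size):
        self.free_blocks[size].append(ptr)  # keep the block instead of returning it

    def empty_cache(self):
        # Return every cached block to the raw allocator, the same way PyTorch
        # releases its cache before retrying a failed allocation.
        for size, ptrs in self.free_blocks.items():
            for ptr in ptrs:
                self.raw_free(ptr, size)
        self.free_blocks.clear()

# Tiny demo with fake "pointers" so the sketch runs without a GPU.
counter = iter(range(1, 1000))
alloc = CachingWrapper(
    raw_malloc=lambda size: print(f"raw malloc {size}") or next(counter),
    raw_free=lambda ptr, size: print(f"raw free {ptr} {size}"),
)
p = alloc.malloc(61440)   # raw malloc happens
alloc.free(p, 61440)      # cached, no raw free
q = alloc.malloc(61440)   # served from the cache, no raw malloc
alloc.free(q, 61440)
alloc.empty_cache()       # only now do the raw frees happen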
I also noticed that torch.cuda.memory_stats() does not work with CUDAPluggableAllocator:
RuntimeError: CUDAPluggableAllocator does not yet support getDeviceStats. If you need it, please file an issue describing your use case.
If we could enable caching for CUDAPluggableAllocator, I think torch.cuda.memory_stats() would also work.
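For reference, with the default caching allocator torch.cuda.memory_stats() returns a flat dict of counters, and the allocated/reserved split directly reflects the caching behavior observed above. A small sketch below; the key names come from the documented stats dict, and the values will of course vary from run to run:
import torch

data = torch.empty((60, 1024 ** 2), dtype=torch.uint8, device="cuda")
del data

stats = torch.cuda.memory_stats()
# Bytes currently used by live tensors vs. bytes PyTorch is still holding from CUDA.
print("allocated:", stats.get("allocated_bytes.all.current"))
print("reserved: ", stats.get("reserved_bytes.all.current"))
# A human-readable report built from the same counters.
print(torch.cuda.memory_summary())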