When using PyTorch Profiler’s export_memory_timeline on a standard (eager-mode) model, the resulting HTML visualization displays memory allocations with distinct colors and labels—parameters and activations are clearly identified.
However, after wrapping the same model in torch.compile, the exported memory timeline changes: every region of the graph is rendered in grey and labeled “unknown.”
Below is a minimal reproducible example that profiles CPU memory for a compiled model.
import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity
class TestNet(nn.Module):
    """Small convolutional classifier used as the profiling subject.

    The network is two Conv2d -> BatchNorm2d -> ReLU stages followed by a
    Flatten and a single Linear layer producing 10 outputs. The Linear
    layer's input size is fixed at ``32 * 224 * 224``, so the model expects
    3-channel inputs of spatial size 224x224 (the 3x3 convolutions use
    ``padding=1`` and therefore keep the spatial size unchanged).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in execution order so parameter initialisation
        # (and the resulting ``seq.<index>`` state-dict keys) stay stable.
        conv_stages = [
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
        ]
        classifier = [
            nn.Flatten(),
            nn.Linear(in_features=32 * 224 * 224, out_features=10),
        ]
        self.seq = nn.Sequential(*conv_stages, *classifier)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        logits = self.seq(x)
        return logits
def run_profile_timeline():
    """Profile one CPU inference of a compiled ``TestNet`` and export its memory timeline.

    The function builds ``TestNet`` in eval mode on the CPU, wraps it with
    ``torch.compile``, runs one un-profiled warm-up call, profiles a second
    call, and writes the profiler's memory timeline to
    ``compiled_cpu_memory_timeline.html`` in the current working directory.

    Returns:
        None. The only outputs are the HTML file and a confirmation line
        printed to stdout.
    """
    device = 'cpu'
    model = TestNet().to(device).eval()
    compiled_model = torch.compile(model)
    x = torch.randn(2, 3, 224, 224, device=device)

    # Both calls run under the same no_grad scope. torch.compile specialises
    # the compiled graph on the grad mode, so profiling with grad enabled
    # after a no_grad warm-up would trigger a recompile (and autograd
    # bookkeeping) inside the profiled region.
    with torch.no_grad():
        # Warm-up: torch.compile compiles lazily on the first call, so this
        # keeps compilation work out of the recorded timeline.
        _ = compiled_model(x)

        # export_memory_timeline needs profile_memory, record_shapes and
        # with_stack to all be enabled on the profiler.
        with profile(
            activities=[ProfilerActivity.CPU],
            profile_memory=True,
            record_shapes=True,
            with_stack=True,
        ) as prof:
            with record_function("compiled_model_inference"):
                _ = compiled_model(x)

    # Export only after the profiler context has exited, so the trace is
    # complete.
    timeline_file = "compiled_cpu_memory_timeline.html"
    prof.export_memory_timeline(timeline_file, device=device)
    print(f"Exported memory timeline to: {timeline_file}")
# Script entry point: run the reproduction only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_profile_timeline()