`export_memory_timeline` doesn't work with `torch.compile`

When using the PyTorch Profiler's `export_memory_timeline` on a standard (eager-mode) model, the resulting HTML visualization displays memory allocations with distinct colors and labels: parameters and activations are clearly identified.
However, after switching to `torch.compile`, the exported memory timeline changes: every portion of the graph is shown as a grey section labeled “unknown”.

Below is a minimal reproducible example that profiles CPU memory.

import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity
 
class TestNet(nn.Module):
    """Small convolutional network used to reproduce the memory-timeline issue.

    Two Conv-BatchNorm-ReLU stages are followed by a flatten and a linear
    classifier. Both convolutions use a 3x3 kernel with ``padding=1`` and the
    default stride of 1, so they keep the spatial size unchanged. The
    classifier's input width therefore depends on the input image size, which
    is exposed as ``image_size`` instead of being hard-coded. The defaults
    reproduce the original fixed architecture exactly.

    Args:
        in_channels: Number of channels in the input images (default 3).
        num_classes: Number of output features of the final linear layer
            (default 10).
        image_size: Height and width of the square input images the model
            will be called on (default 224).
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        image_size: int = 224,
    ):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            # The convs above preserve H and W, so the flattened feature
            # count is 32 channels * image_size * image_size.
            nn.Linear(32 * image_size * image_size, num_classes),
        )

    def forward(self, x):
        """Run ``x`` through the network.

        ``x`` is assumed to be shaped (N, in_channels, image_size, image_size);
        any other spatial size will not match the linear layer's input width.
        """
        return self.seq(x)
 
def run_profile_timeline(timeline_file="compiled_cpu_memory_timeline.html"):
    """Profile one compiled CPU inference call and export its memory timeline.

    The model is compiled and called once outside the profiler (warm-up), so
    that compilation happens before the profiled region. The profiled call
    must then run under the same grad mode as the warm-up. TorchDynamo's
    guards include the global grad mode, so a mismatch would trigger a
    recompilation *inside* the profiler instead of replaying the warmed-up
    graph.

    Args:
        timeline_file: Output path passed to
            ``prof.export_memory_timeline``; the ``.html`` suffix selects the
            HTML plot output.
    """
    device = 'cpu'
    model = TestNet().to(device).eval()
    compiled_model = torch.compile(model)
    x = torch.randn(2, 3, 224, 224, device=device)

    # Warm-up: compile outside the profiled region.
    with torch.no_grad():
        _ = compiled_model(x)

    with profile(
        activities=[ProfilerActivity.CPU],
        profile_memory=True,
        record_shapes=True,
        with_stack=True,
    ) as prof:
        # Same no_grad mode as the warm-up, so the compiled graph is reused
        # rather than recompiled while profiling.
        with record_function("compiled_model_inference"), torch.no_grad():
            _ = compiled_model(x)

    prof.export_memory_timeline(timeline_file, device=device)
    print(f"Exported memory timeline to: {timeline_file}")
 
if __name__ == "__main__":
    # Run the reproduction only when this file is executed as a script,
    # not when it is imported.
    run_profile_timeline()