Hi folks,
We’ve been working on an interactive “Debug Server” that consists of an interactive frontend along with HTTP handlers that run on every rank. The key idea is to have a single place for generally useful tools for debugging distributed jobs — tools that can run across all workers simultaneously. The initial focus is on hangs and slow performance, but we expect it to be broadly useful beyond that.
Current list of tools (landed/in progress):
- Python stack traces (Python only)
- py-spy native stack traces (Python/C++)
- FlightRecorder analysis + raw JSON for CPU and NCCL
- torch.profiler support on demand (integrated into Perfetto)
- WaitCounters
- TCPStore debugging
The frontend server runs in a subprocess of rank 0 and, when requested, fans out HTTP requests to every worker to collect the corresponding data.
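To make the fan-out pattern concrete, here is a self-contained, stdlib-only sketch of the idea — this is illustrative, not the actual PyTorch implementation: each "worker" is a tiny HTTP server, and a frontend issues one request per worker in parallel and aggregates the JSON responses.

```python
# Illustrative sketch of frontend fan-out (NOT the actual torch.distributed.debug
# implementation): rank 0 queries a per-rank HTTP endpoint on every worker
# and aggregates the responses.
import json
import threading
from concurrent.futures import ThreadPoolExecutor
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.request import urlopen


class StacksHandler(BaseHTTPRequestHandler):
    """Stands in for a per-rank worker server handler."""

    def do_GET(self):
        # In the real system this would return stacks, FlightRecorder data, etc.
        body = json.dumps({"rank": self.server.rank, "stacks": "..."}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *args):
        pass  # silence per-request logging


def start_worker(rank):
    # Bind port 0 so the OS picks a free ephemeral port.
    server = HTTPServer(("127.0.0.1", 0), StacksHandler)
    server.rank = rank
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


def fan_out(addresses):
    # One HTTP GET per worker, issued in parallel, results aggregated.
    def fetch(addr):
        with urlopen(f"http://{addr[0]}:{addr[1]}/stacks") as resp:
            return json.loads(resp.read())

    with ThreadPoolExecutor(max_workers=32) as pool:
        return list(pool.map(fetch, addresses))


workers = [start_worker(rank) for rank in range(4)]
results = fan_out([w.server_address for w in workers])
print(sorted(r["rank"] for r in results))  # → [0, 1, 2, 3]
```

At scale, the same pattern applies; the interesting engineering is in bounding concurrency and payload sizes, which is where large traces can hurt.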
We’ve tested this at 100k ranks with simple requests, but for certain outputs (e.g. very large traces) things may get much slower, and we may need to iterate on those tools to make them smarter.
We’d love to get feedback on which tools are useful and any issues you run into. This is also designed to be very hackable and easy to add tools to, so we’d welcome any PRs/new analyzers.
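As an illustration of the kind of per-rank analyzer these handlers wrap, here is a stdlib-only sketch (again illustrative — the real Python-stacks tool lives under torch/distributed/debug) that dumps the current stack of every thread in the process:

```python
# Stdlib-only sketch of a "Python stack traces" style analyzer
# (illustrative; not the actual torch.distributed.debug implementation).
import sys
import traceback


def format_all_stacks() -> str:
    """Return the current stack trace of every thread as one string."""
    out = []
    # sys._current_frames() snapshots the topmost frame of every thread.
    for thread_id, frame in sys._current_frames().items():
        out.append(f"--- thread {thread_id} ---")
        out.extend(line.rstrip() for line in traceback.format_stack(frame))
    return "\n".join(out)


print(format_all_stacks())
```

A per-rank HTTP handler can return a string like this directly, and the frontend only needs to concatenate the responses, labeled by rank.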
Here’s a very simple version that enables this:
"""
Invoke with:
torchrun --nnodes 1 --nproc_per_node=gpu ~/scripts/debug_test.py
"""
import os
import time
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.debug import start_debug_server
start_debug_server()
RANK = int(os.environ["RANK"])
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
device = f"cuda:{LOCAL_RANK}"
dist.init_process_group("nccl")
# Simple model
model = nn.Sequential(
    nn.Linear(10, 2048),
    nn.ReLU(),
    nn.Linear(2048, 1),
).to(device)
model = DDP(model, device_ids=[LOCAL_RANK])
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
# Toy data
X = torch.randn(100, 10, device=device)
y = torch.randn(100, 1, device=device)
# Training loop
for epoch in range(1000000000):
    optimizer.zero_grad()
    pred = model(X)
    loss = loss_fn(pred, y)
    loss.backward()
    optimizer.step()

    # Rank 0 sleeps each step to create a visible straggler for the
    # debug server to inspect.
    if RANK == 0:
        time.sleep(1)

    dist.barrier()

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
References:
- Entrypoint: pytorch/torch/distributed/debug/__init__.py
- Frontend: pytorch/torch/distributed/debug/_frontend.py
- C++ handler/WorkerServer impl: pytorch/torch/csrc/distributed/c10d/control_plane
- A number of PRs are landing adding more features, e.g. “dist/debug: support py-spy native stacks” by d4l3k (pytorch/pytorch #169147)