Distributed Debug Server

Hi folks,

We’ve been working on an interactive “Debug Server”: a web frontend plus HTTP handlers that run on every rank. The key idea is to have a single place to collect generally useful tools for debugging distributed jobs, with each tool able to run across all workers simultaneously. We expect this to be broadly useful for debugging distributed jobs, but the initial focus is on hangs and slow performance.

Current list of tools (landed/in progress):

  • Python stack traces (Python only)
  • py-spy native stack traces (Python/C++)
  • FlightRecorder analysis + raw JSON for CPU and NCCL
  • torch.profiler support on demand (integrated with Perfetto)
  • WaitCounters
  • TCPStore debugging

The frontend server runs in a subprocess of rank 0; when a request comes in, it fans out HTTP requests to every worker and aggregates the responses.

We’ve tested this at 100k ranks with simple requests, but for certain outputs (e.g. very large traces) things may get much slower, and we may need to iterate on those handlers to make them smarter.
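To make the fan-out pattern concrete, here’s a minimal self-contained sketch. This is illustrative only, not the actual torch.distributed.debug implementation: the `/stacks` endpoint, the handler class, and the response payloads are all made up for the example. The idea is simply that the rank 0 frontend issues concurrent HTTP GETs to each worker’s handler and collects the responses.

```python
# Illustrative sketch of the fan-out pattern; NOT the real
# torch.distributed.debug code. Endpoint names are hypothetical.
import concurrent.futures
import http.server
import threading
import urllib.request


class StacksHandler(http.server.BaseHTTPRequestHandler):
    """Per-worker handler that serves some debug payload for its rank."""

    rank = 0  # overridden per worker below

    def do_GET(self):
        body = f"stack trace from rank {self.rank}".encode()
        self.send_response(200)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *args):
        pass  # silence per-request logging


def start_worker(rank):
    """Start one worker's HTTP handler on an OS-assigned port."""
    handler = type("H", (StacksHandler,), {"rank": rank})
    server = http.server.HTTPServer(("127.0.0.1", 0), handler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server.server_address[1]


def fan_out(ports, timeout=5.0):
    """Rank 0's frontend: query every worker concurrently, keep rank order."""
    def fetch(port):
        url = f"http://127.0.0.1:{port}/stacks"
        with urllib.request.urlopen(url, timeout=timeout) as r:
            return r.read().decode()

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as pool:
        return list(pool.map(fetch, ports))


ports = [start_worker(rank) for rank in range(4)]
results = fan_out(ports)
for line in results:
    print(line)
```

At real scale the fan-out needs timeouts and partial-failure handling per worker, since a hung rank is exactly the case you’re debugging.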

We’d love feedback on which tools are useful and any issues you run into. It’s also designed to be very hackable and easy to add tools to, so we’d welcome any PRs/new analyzers.

Here’s a very simple version that enables this:

"""
Invoke with:

torchrun --nnodes 1 --nproc_per_node=gpu ~/scripts/debug_test.py
"""

import os
import time

os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from torch.distributed.debug import start_debug_server

start_debug_server()

RANK = int(os.environ["RANK"])
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
device = f"cuda:{LOCAL_RANK}"

dist.init_process_group("nccl")

# Simple model
model = nn.Sequential(
    nn.Linear(10, 2048),
    nn.ReLU(),
    nn.Linear(2048, 1)
).to(device)

model = DDP(model, device_ids=[LOCAL_RANK])

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
# Toy data
X = torch.randn(100, 10, device=device)
y = torch.randn(100, 1, device=device)
# Training loop: runs (effectively) forever so there's time to poke at the debug server
for epoch in range(1000000000):
    optimizer.zero_grad()
    pred = model(X)
    loss = loss_fn(pred, y)
    loss.backward()
    optimizer.step()

    # Artificial straggler: rank 0 sleeps each step, so the other ranks
    # block in the barrier — something interesting for the debug server to show
    if RANK == 0:
        time.sleep(1)

    dist.barrier()

    # epoch % 1 is always 0, so this logs every step; raise the modulus to log less often
    if epoch % 1 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


It’s a really great feature. Can we also run py-spy to debug hangs in data loader processes?


dist/debug: support py-spy native stacks by d4l3k · Pull Request #169147 · pytorch/pytorch · GitHub enables py-spy.

It should be pretty simple to add a follow-up PR to optionally enable --subprocesses.
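For reference, py-spy’s own CLI already exposes both behaviors as flags (the PID below is a placeholder):

```shell
# Dump native (Python + C/C++) stacks for a process and all of its
# subprocesses (e.g. DataLoader workers). 1234 is a placeholder PID.
py-spy dump --pid 1234 --native --subprocesses
```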

@irshadcc added native and subprocess options