Agent skill

debug-distributed

Guide for debugging distributed training issues in AReaL. Use when the user encounters hangs, wrong results, OOM errors, or communication errors.


Install this agent skill to your project:

npx add-skill https://github.com/majiayu000/claude-skill-registry/tree/main/skills/data/debug-distributed

SKILL.md

Debug Distributed Training

Debugging guide for distributed training issues in AReaL (FSDP2, TP, CP, EP).

When to Use

This skill is triggered when:

  • Training hangs or deadlocks
  • Results differ across ranks or are numerically wrong
  • OOM errors in distributed settings
  • NCCL/communication errors or device mesh issues

Debugging Principles

Minimal Reproduction

Always follow the minimal-reproduction principle: reproduce the failure with the least amount of code possible so the issue can be narrowed down faster.

python
# Bad: Debug in full training loop
# Good: Create minimal script
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())  # pin each rank to its own GPU

# Reproduce the exact operation that fails
tensor = torch.ones(10, device="cuda")
dist.all_reduce(tensor)  # <-- Isolate the failing op
print(f"Rank {rank}: {tensor}")

Reduction strategy:

  1. Remove unrelated model components
  2. Use small tensor sizes
  3. Reduce world_size to the minimum, e.g. 2 GPUs (see the spawn sketch after this list)
  4. Remove torch.compile if possible
  5. Disable activation checkpointing
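
A quick way to apply steps 2-3 is a single-node, two-process reproduction that needs no external launcher. A minimal sketch using torch.multiprocessing (the port number is arbitrary):

python
# Spawn a 2-process repro on one node so the failing collective can be
# iterated on quickly without a cluster launcher.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    tensor = torch.ones(8, device="cuda")   # small tensor (step 2)
    dist.all_reduce(tensor)                 # the operation under suspicion
    print(f"Rank {rank}: {tensor.sum().item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)   # world_size = 2 (step 3)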

Step-by-Step Debugging Guide

1. Hang Debugging (Deadlocks, Synchronization)

Environment Variables for Debugging:

bash
# Full debug logging
export TORCH_DISTRIBUTED_DEBUG=DETAIL
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL

# torch.compile debugging
export TORCH_LOGS="+dynamo,recompiles"
export TORCHDYNAMO_VERBOSE=1

Dump Call Stack with py-spy (for hung processes):

bash
# Find process IDs
ps aux | grep python

# Dump call stack of specific rank
py-spy dump --pid <PID>

# Record flame graph for performance analysis
py-spy record -o profile.svg --pid <PID> --duration 30

Common Causes:

  1. Mismatched Collectives: one rank calls all_reduce while another never reaches it (see the sketch below).
  2. Wrong Process Group: a collective is issued on a group that doesn't match across the participating ranks.
  3. Tensor Shape Mismatch: ranks pass different shapes to the same collective.
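
Cause 1 can be reproduced in isolation with a deliberately mismatched collective. In this plain-PyTorch sketch, launched on two ranks, rank 0 blocks inside all_reduce until the NCCL timeout because rank 1 never joins:

python
# Deliberately broken: rank 0 enters a collective that rank 1 skips, so rank 0
# hangs until the process-group timeout fires.
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

tensor = torch.ones(4, device="cuda")
if rank == 0:
    dist.all_reduce(tensor)   # rank 1 never calls this -> rank 0 blocks here
print(f"Rank {rank}: reached the end")   # only rank 1 prints promptly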

Debug Steps:

python
# Verify group membership
mesh = parallel_dims.get_mesh("dp_shard_cp")
group = mesh.get_group()
print(f"Rank {dist.get_rank()}: group size = {dist.get_world_size(group)}")

# Print shapes on all ranks
print(f"Rank {dist.get_rank()}: tensor.shape = {tensor.shape}")
dist.barrier()

Timeout Adjustment (for debugging only):

python
from areal.utils.distributed import patch_dist_group_timeout
from datetime import timedelta
patch_dist_group_timeout(timedelta(minutes=30))

2. Wrong Results (Gradient, Reduction Issues)

Check DTensor Placements:

python
from torch.distributed.tensor import DTensor

# Walk all parameters and show how each DTensor is sharded/replicated
for name, param in model.named_parameters():
    if isinstance(param, DTensor):
        print(f"Param {name}: placements={param.placements}, mesh={param.device_mesh}")

Verify Gradient Reduction:

python
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"Rank {dist.get_rank()}: {name} grad_sum = {param.grad.sum().item()}")

3. OOM Issues (Memory, Sharding)

Check Memory Usage:

python
print(f"Rank {dist.get_rank()}: "
      f"allocated={torch.cuda.memory_allocated()/1e9:.2f}GB, "
      f"reserved={torch.cuda.memory_reserved()/1e9:.2f}GB")

Check FSDP Coverage:

python
from torch.distributed.tensor import DTensor

# Parameters wrapped by FSDP2 appear as DTensors; plain tensors were not sharded
for name, param in model.named_parameters():
    is_dtensor = isinstance(param, DTensor)
    print(f"{name}: is_dtensor={is_dtensor}, shape={param.shape}")

4. Communication Errors

| Error | Cause | Solution |
| --- | --- | --- |
| NCCL WARN Cuda failure | GPU/communication fault | Check NCCL version and GPU topology |
| RuntimeError: Timed out | Rank synchronization (one rank never reaches the collective) | Increase the timeout, compare code paths across ranks |
| Invalid device mesh | Mesh configuration | Verify world_size = dp * tp * cp (see the check below) |
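
For the device-mesh case, a cheap sanity check before building the mesh is to confirm that the configured parallel degrees multiply to the launched world size. A minimal sketch with illustrative degree names (adapt to however your config spells them):

python
import torch.distributed as dist

def check_mesh_config(dp: int, tp: int, cp: int, pp: int = 1) -> None:
    # The product of all parallel degrees must equal the number of launched ranks.
    world_size = dist.get_world_size()
    expected = dp * tp * cp * pp
    assert world_size == expected, (
        f"world_size={world_size} != dp*tp*cp*pp={expected}; "
        "adjust the parallel degrees or the number of launched ranks"
    )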

Debugging Tools

Environment Variables Reference

| Variable | Purpose |
| --- | --- |
| TORCH_DISTRIBUTED_DEBUG=DETAIL | Detailed distributed logging |
| NCCL_DEBUG=INFO | NCCL communication logging |
| NCCL_DEBUG_SUBSYS=ALL | All NCCL subsystems |
| TORCH_LOGS="+dynamo,recompiles" | torch.compile logging |
| TORCHDYNAMO_VERBOSE=1 | Dynamo verbose output |
| CUDA_LAUNCH_BLOCKING=1 | Synchronous CUDA kernel launches (slow; debugging only) |

py-spy for Call Stack Analysis

bash
# Install
pip install py-spy

# Dump call stack of hung process
py-spy dump --pid <PID>

# Dump all Python processes
pgrep -f python | xargs -I {} py-spy dump --pid {}

# Record flame graph
py-spy record -o profile.svg --pid <PID> --duration 30

Rank-Conditional Printing

python
def print_all_ranks(msg):
    # Print in rank order, with a barrier each step, so output is not interleaved
    for r in range(dist.get_world_size()):
        if dist.get_rank() == r:
            print(f"[Rank {r}] {msg}")
        dist.barrier()

Check Device Mesh

python
def debug_mesh(parallel_dims):
    # Print each mesh dimension's size on every rank to spot misconfiguration
    mesh = parallel_dims.world_mesh
    for dim_name in mesh.mesh_dim_names:
        submesh = parallel_dims.get_mesh(dim_name)
        if submesh is not None:
            print(f"Rank {dist.get_rank()}: {dim_name} size={submesh.size()}")

Validate Tensor Consistency

python
def check_tensor_consistency(tensor, name, group=None):
    # Gather a cheap scalar summary from every rank and warn when they disagree
    local_sum = tensor.sum().item()
    tensor_sums = [None] * dist.get_world_size(group)
    dist.all_gather_object(tensor_sums, local_sum, group=group)
    if dist.get_rank() == 0 and len(set(tensor_sums)) > 1:
        print(f"WARNING: {name} inconsistent: {tensor_sums}")

Key Files Reference

| Component | File |
| --- | --- |
| Parallel Dims | areal/experimental/models/archon/parallel_dims.py |
| Expert Parallel | areal/experimental/models/archon/expert_parallel.py |
| Ulysses (CP) | areal/experimental/models/archon/ulysses.py |
| FSDP/TP Apply | areal/experimental/models/archon/qwen2/infra/parallelize.py |
