# CentOS
sudo yum update -y
sudo yum install -y gcc-c++ make cmake git
# Ubuntu
sudo apt-get update && sudo apt-get install -y build-essential cmake git
ping <node_ip>   # test that the nodes can reach each other
# CentOS (firewalld)
sudo firewall-cmd --zone=public --add-port=23456/tcp --permanent
sudo firewall-cmd --reload
# Ubuntu (ufw)
sudo ufw allow 23456/tcp
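Beyond ping, it can help to confirm the chosen port is actually reachable from a worker node. The snippet below is a minimal sketch using only the standard library; the address and port are the example values used later in this guide, so replace them with your own.

import socket

MASTER_ADDR = "192.168.1.100"   # example master IP from this guide; replace with yours
MASTER_PORT = 23456             # the port opened in the firewall above

try:
    # attempt a plain TCP connection from a worker node to the master
    with socket.create_connection((MASTER_ADDR, MASTER_PORT), timeout=5):
        print("port reachable")
except OSError as exc:
    print(f"cannot reach {MASTER_ADDR}:{MASTER_PORT}: {exc}")

A "connection refused" here typically means the port is open but nothing is listening yet (expected before training starts), while a timeout often points to a firewall or routing problem.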
ssh-keygen -t rsa              # press Enter through the prompts; the default path is ~/.ssh/id_rsa
ssh-copy-id user@worker1_ip    # replace with the worker node's username and IP
ssh-copy-id user@worker2_ip
ssh user@worker1_ip            # should log in without a password prompt
# default build from PyPI
pip3 install torch torchvision torchaudio
# build for CUDA 11.3 (GPU nodes)
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
import torch
print(torch.cuda.is_available())  # should print True
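A slightly fuller check (a sketch using standard torch APIs) also confirms the GPU count on this node and that the NCCL backend used later is available:

import torch
import torch.distributed as dist

print(torch.__version__, torch.version.cuda)   # PyTorch and CUDA versions
print(torch.cuda.device_count())               # number of visible GPUs on this node
print(dist.is_nccl_available())                # NCCL backend availability (needed for backend='nccl' below)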
pip3 install dask distributed   # install Dask (optional, for cluster monitoring)
pip3 install mpi4py             # install mpi4py (optional MPI bindings)
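If you install mpi4py, a quick smoke test confirms that multiple processes can see each other. This is a sketch that assumes an MPI runtime such as OpenMPI is installed; run it with mpirun -np 4 python3 mpi_check.py.

# mpi_check.py
from mpi4py import MPI

comm = MPI.COMM_WORLD
print(f"MPI rank {comm.Get_rank()} of {comm.Get_size()}")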
In the training script, initialize the distributed environment with torch.distributed.init_process_group:
import torch.distributed as dist

def setup(rank, world_size):
    # Initialize the process group; the NCCL backend is recommended for GPU training
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://<master_ip>:<master_port>',  # master node IP and port
        world_size=world_size,  # total number of processes (nodes x GPUs per node)
        rank=rank               # global rank of this process (0 to world_size-1)
    )

def cleanup():
    dist.destroy_process_group()  # destroy the process group when training finishes
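Instead of hard-coding a tcp:// address, init_process_group can also read the rendezvous information from the environment. The sketch below assumes MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are exported by the launcher introduced in a later step; it is an alternative to the setup() above, not a required addition.

import torch.distributed as dist

def setup_from_env():
    # init_method='env://' reads MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE from the environment
    dist.init_process_group(backend='nccl', init_method='env://')
    return dist.get_rank(), dist.get_world_size()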
Wrap the model with torch.nn.parallel.DistributedDataParallel (DDP) to parallelize over the data:
import torch.nn as nn

model = YourModel().to(rank)  # move the model onto this process's GPU
ddp_model = nn.parallel.DistributedDataParallel(
    model,
    device_ids=[rank]  # GPU index for this process; on multi-node runs use the node-local rank (local_rank) here and in .to()
)
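One detail worth noting about the wrapped model: the original module sits under ddp_model.module, and checkpoints are usually written by a single process. A minimal sketch, assuming the rank variable and ddp_model defined above:

import torch

if rank == 0:
    # Only the process with global rank 0 writes the checkpoint, to avoid concurrent writes
    torch.save(ddp_model.module.state_dict(), "checkpoint.pt")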
Use DistributedSampler so that each process loads a different subset of the data:
from torch.utils.data import DataLoader, DistributedSampler

dataset = YourDataset()  # your custom dataset
sampler = DistributedSampler(
    dataset,
    num_replicas=world_size,  # total number of processes
    rank=rank                 # rank of this process
)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler  # use the distributed sampler instead of shuffle=True
)
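Data loading itself can become the bottleneck in multi-GPU runs. Below is a hedged variant of the same DataLoader with common throughput options; the values are illustrative, not tuned.

dataloader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler,
    num_workers=4,        # parallel worker processes for loading/augmentation (illustrative value)
    pin_memory=True,      # page-locked host memory speeds up host-to-GPU copies
    drop_last=True        # drop the final incomplete batch so every step uses a full batch
)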
In the training loop, call sampler.set_epoch(epoch) at the start of each epoch so the shuffling order changes between epochs:
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # reseed the sampler each epoch so the shards are reshuffled
    for data, target in dataloader:
        data, target = data.to(rank), target.to(rank)
        optimizer.zero_grad()
        output = ddp_model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
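Putting the pieces together, a minimal driver might look like the sketch below. It assumes the legacy torch.distributed.launch launcher used in the next step, which passes a --local_rank argument to each process and exports RANK and WORLD_SIZE in the environment; it reuses the setup() and cleanup() defined above, and the model, dataset and optimizer construction are elided.

import argparse
import os
import torch

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)  # injected by torch.distributed.launch
    args = parser.parse_args()

    rank = int(os.environ["RANK"])              # global rank, set by the launcher
    world_size = int(os.environ["WORLD_SIZE"])  # total processes, set by the launcher
    torch.cuda.set_device(args.local_rank)      # bind this process to its GPU on the node

    setup(rank, world_size)   # init_process_group as defined above
    # ... build the model, DDP wrapper, sampler, dataloader and optimizer, then run the training loop ...
    cleanup()

if __name__ == "__main__":
    main()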
Start distributed training by running the following command on each node (the placeholders are explained in the comments):
# <num_gpus>:           GPUs per node (e.g. 4)
# <total_nodes>:        total number of nodes (e.g. 2)
# <current_node_rank>:  rank of this node (master is 0, workers are 1, 2, ...)
# <master_ip>, <port>:  master node IP and port (e.g. 23456)
python -m torch.distributed.launch \
    --nproc_per_node=<num_gpus> \
    --nnodes=<total_nodes> \
    --node_rank=<current_node_rank> \
    --master_addr=<master_ip> \
    --master_port=<port> \
    your_training_script.py
# on the master node (node_rank=0):
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="192.168.1.100" --master_port=23456 train.py
# on the worker node (node_rank=1):
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="192.168.1.100" --master_port=23456 train.py
Write a launch script, start_train.sh, so each node can start its own command with a single call:
#!/bin/bash
# Master node IP and port
MASTER_ADDR="192.168.1.100"
MASTER_PORT=23456
# Total number of nodes
NNODES=2
# GPUs per node
GPUS_PER_NODE=4

# Run on the master node
if [ "$1" == "master" ]; then
    echo "Starting master node..."
    python -m torch.distributed.launch \
        --nproc_per_node=$GPUS_PER_NODE \
        --nnodes=$NNODES \
        --node_rank=0 \
        --master_addr=$MASTER_ADDR \
        --master_port=$MASTER_PORT \
        train.py
# Run on a worker node (with more than one worker, each needs a distinct --node_rank)
elif [ "$1" == "worker" ]; then
    echo "Starting worker node..."
    python -m torch.distributed.launch \
        --nproc_per_node=$GPUS_PER_NODE \
        --nnodes=$NNODES \
        --node_rank=1 \
        --master_addr=$MASTER_ADDR \
        --master_port=$MASTER_PORT \
        train.py
else
    echo "Usage: $0 {master|worker}"
fi
chmod +x start_train.sh
./start_train.sh master   # on the master node
./start_train.sh worker   # on each worker node
To verify that the distributed environment is set up correctly, print rank and world_size inside the training script:

print(f"Rank {rank}, World Size {world_size}")

Each process should report a distinct rank (0 to world_size-1). If you use Dask, its dashboard at http://<master_ip>:8787 shows task progress and cluster status. Use torch.profiler to monitor training performance and identify bottlenecks:

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=lambda prof: prof.export_chrome_trace("trace.json")
) as prof:
    for data, target in dataloader:
        # training step goes here
        prof.step()
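Besides the Chrome trace, you can print an aggregate summary once the profiled iterations finish. This uses the standard profiler API; sorting by total CUDA time is just one common choice.

# Print the top operators by total CUDA time after the profiling context exits
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))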
# CentOS
sudo yum install -y ntp
sudo systemctl start ntpd
sudo systemctl enable ntpd
# Ubuntu
sudo apt-get install -y ntp
sudo systemctl start ntp
sudo systemctl enable ntp
Use requirements.txt or environment.yml to keep the Python environment identical on every node:

pip freeze > requirements.txt   # export the master node's environment
# install the same environment on each worker node
pip install -r requirements.txt
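A quick way to spot remaining mismatches is to print the key versions on every node and compare the output line by line; this is a small sketch, not a substitute for a proper environment file.

import sys
import torch

# Run this on every node; the printed line should match exactly across nodes
print(sys.version.split()[0], torch.__version__, torch.version.cuda)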
Set --nproc_per_node to match the number of GPUs on each node so no resources are wasted; world_size must equal "number of nodes x GPUs per node". If training hangs or a node fails to join, test connectivity with ping <node_ip>, confirm passwordless login with ssh user@worker_ip, and check each node's training log (e.g. train.log) to locate the error message.
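Finally, a small connectivity sanity check exercises the NCCL path directly. This is a sketch to be launched with the same launcher and setup() as the real training script: every rank contributes a 1, and the all_reduce result should equal world_size on every process.

import torch
import torch.distributed as dist

# Assumes the process group has already been initialized via setup(rank, world_size)
t = torch.ones(1).cuda()   # one element per process, on its GPU
dist.all_reduce(t)         # default reduction op is SUM across all processes
print(f"rank {dist.get_rank()}: all_reduce result = {t.item()} (expected {dist.get_world_size()})")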