linux

Linux PyTorch集群搭建指南

小樊
38
2025-10-06 00:55:06
栏目: 智能运维

Linux环境下PyTorch集群搭建指南

一、前置准备

1. 硬件要求

2. 软件与环境要求

3. 网络配置

4. SSH无密码登录

二、PyTorch安装

1. 安装PyTorch

2. 可选工具安装

三、分布式训练配置

1. 初始化进程组

在训练脚本中,使用torch.distributed.init_process_group初始化分布式环境:

import torch.distributed as dist

def setup(rank, world_size):
    # 初始化进程组,推荐使用NCCL后端(GPU加速)
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://<master_ip>:<master_port>',  # 主节点IP和端口
        world_size=world_size,  # 总进程数(节点数×每个节点GPU数)
        rank=rank  # 当前进程的全局排名(0到world_size-1)
    )

def cleanup():
    dist.destroy_process_group()  # 训练结束后销毁进程组

2. 数据并行封装

使用torch.nn.parallel.DistributedDataParallel(DDP)包装模型,实现数据并行:

import torch.nn as nn

model = YourModel().to(rank)  # 将模型移动到当前GPU
ddp_model = nn.parallel.DistributedDataParallel(
    model,
    device_ids=[rank]  # 当前节点的GPU索引
)

3. 数据加载调整

使用DistributedSampler确保每个进程加载不同的数据子集:

from torch.utils.data import DataLoader, DistributedSampler

dataset = YourDataset()  # 自定义数据集
sampler = DistributedSampler(
    dataset,
    num_replicas=world_size,  # 总进程数
    rank=rank  # 当前进程排名
)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler  # 使用分布式采样器
)

4. 训练循环修改

在训练循环中,每轮迭代前调用sampler.set_epoch(epoch),确保数据打乱顺序:

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # 每轮重置采样器,避免数据重复
    for data, target in dataloader:
        data, target = data.to(rank), target.to(rank)
        optimizer.zero_grad()
        output = ddp_model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

四、启动分布式训练

1. 使用torch.distributed.launch工具

在主节点运行以下命令,启动分布式训练:

python -m torch.distributed.launch \
    --nproc_per_node=<num_gpus> \  # 每个节点的GPU数量(如4)
    --nnodes=<total_nodes> \       # 总节点数(如2)
    --node_rank=<current_node_rank> \  # 当前节点排名(主节点0,工作节点1、2...)
    --master_addr=<master_ip> \    # 主节点IP
    --master_port=<port> \         # 主节点端口(如23456)
    your_training_script.py        # 训练脚本路径

2. 使用启动脚本(简化操作)

编写启动脚本start_train.sh,自动分发命令到各节点:

#!/bin/bash
# 主节点IP和端口
MASTER_ADDR="192.168.1.100"
MASTER_PORT=23456
# 总节点数
NNODES=2
# 每个节点的GPU数量
GPUS_PER_NODE=4

# 主节点运行
if [ "$1" == "master" ]; then
    echo "Starting master node..."
    python -m torch.distributed.launch \
        --nproc_per_node=$GPUS_PER_NODE \
        --nnodes=$NNODES \
        --node_rank=0 \
        --master_addr=$MASTER_ADDR \
        --master_port=$MASTER_PORT \
        train.py
# 工作节点运行
elif [ "$1" == "worker" ]; then
    echo "Starting worker node..."
    python -m torch.distributed.launch \
        --nproc_per_node=$GPUS_PER_NODE \
        --nnodes=$NNODES \
        --node_rank=1 \
        --master_addr=$MASTER_ADDR \
        --master_port=$MASTER_PORT \
        train.py
else
    echo "Usage: $0 {master|worker}"
fi

五、验证与监控

1. 验证集群功能

2. 监控工具

六、注意事项

1. 时间同步

2. 环境一致性

3. 资源分配

4. 故障排查

0
看了该问题的人还看了