linux

PyTorch在Linux上的分布式训练如何操作

小樊
42
2025-08-01 12:24:09
栏目: 智能运维

在Linux上进行PyTorch的分布式训练,你需要遵循以下步骤:

  1. 环境准备

    • 确保所有节点上都安装了相同版本的PyTorch和依赖库。
    • 确保所有节点可以通过SSH无密码登录。
  2. 初始化分布式环境

    • 使用torch.distributed.launch工具或者accelerate库来启动分布式训练。
    • 设置环境变量WORLD_SIZE为参与训练的进程总数。
    • 设置环境变量RANK为当前进程的排名(从0开始)。
    • 设置环境变量MASTER_ADDR为主节点的IP地址。
    • 设置环境变量MASTER_PORT为一个未被使用的端口号。
  3. 编写分布式训练代码

    • 在代码中使用torch.distributed.init_process_group()函数来初始化分布式环境。
    • 使用torch.nn.parallel.DistributedDataParallel来包装你的模型,以便它可以在多个GPU上并行训练。
  4. 运行分布式训练

    • 使用mpirunmpiexec或者torch.distributed.launch来启动多个进程进行训练。

下面是一个简单的例子,展示了如何在Linux上使用torch.distributed.launch进行分布式训练:

# 假设你有4个GPU,并且在4个节点上进行训练
# 在每个节点上运行以下命令

# 节点0
WORLD_SIZE=4 RANK=0 MASTER_ADDR='master_ip' MASTER_PORT='12345' python -m torch.distributed.launch --nproc_per_node=4 your_training_script.py

# 节点1
WORLD_SIZE=4 RANK=1 MASTER_ADDR='master_ip' MASTER_PORT='12345' python -m torch.distributed.launch --nproc_per_node=4 your_training_script.py

# 节点2
WORLD_SIZE=4 RANK=2 MASTER_ADDR='master_ip' MASTER_PORT='12345' python -m torch.distributed.launch --nproc_per_node=4 your_training_script.py

# 节点3
WORLD_SIZE=4 RANK=3 MASTER_ADDR='master_ip' MASTER_PORT='12345' python -m torch.distributed.launch --nproc_per_node=4 your_training_script.py

your_training_script.py中,你需要初始化分布式环境并包装你的模型:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

def main(rank, world_size):
    # 初始化分布式环境
    torch.distributed.init_process_group(
        backend='nccl',  # 'nccl' is recommended for distributed GPU training
        init_method=f'tcp://{MASTER_ADDR}:{MASTER_PORT}',
        world_size=world_size,
        rank=rank
    )

    # 创建模型并移动到对应的GPU
    model = YourModel().to(rank)

    # 包装模型以进行分布式训练
    ddp_model = DDP(model, device_ids=[rank])

    # 创建损失函数和优化器
    criterion = nn.CrossEntropyLoss().to(rank)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)

    # 训练循环...
    for epoch in range(num_epochs):
        # 训练代码...
        pass

    # 清理分布式环境
    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--rank', type=int)
    parser.add_argument('--world_size', type=int)
    args = parser.parse_args()

    main(args.rank, args.world_size)

请注意,这只是一个基本的例子。在实际应用中,你需要根据自己的需求调整代码,比如数据加载、模型定义、训练循环等。此外,如果你的集群使用的是YARN、Slurm或其他作业调度系统,你可能需要使用相应的工具来启动分布式训练。

0
看了该问题的人还看了