linux

PyTorch在Linux上的分布式训练如何实现

小樊
32
2025-06-21 10:41:53
栏目: 智能运维

PyTorch在Linux上的分布式训练可以通过多种方式实现,主要包括使用torch.distributed包和torch.multiprocessing模块。以下是一个基本的步骤指南:

  1. 环境准备

    • 确保所有节点上都安装了相同版本的PyTorch和依赖库。
    • 设置好网络环境,确保节点间可以互相通信。
  2. 代码编写

    • 使用torch.distributed.launch工具或者自定义脚本来启动分布式训练。
    • 在代码中,使用torch.distributed.init_process_group来初始化分布式环境。
    • 使用torch.nn.parallel.DistributedDataParallel来包装模型,以实现数据的并行处理。
  3. 运行分布式训练

    • 在每个节点上运行相同的训练脚本,但是需要设置不同的命令行参数,如节点的IP地址、端口号、总节点数、当前节点的排名等。

以下是一个简单的示例代码,展示了如何在PyTorch中实现分布式训练:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
import torchvision.datasets as datasets
import torchvision.transforms as transforms

def main(rank, world_size):
    # 初始化分布式环境
    dist.init_process_group(backend='nccl', init_method='tcp://<master_ip>:<master_port>', world_size=world_size, rank=rank)

    # 创建模型并移动到对应的GPU
    model = nn.Sequential(
        nn.Linear(784, 1024),
        nn.ReLU(),
        nn.Linear(1024, 10)
    ).to(rank)

    # 使用DistributedDataParallel包装模型
    model = DDP(model, device_ids=[rank])

    # 加载数据集
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)

    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # 训练模型
    for epoch in range(5):
        sampler.set_epoch(epoch)
        running_loss = 0.0
        for data, target in loader:
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            output = model(data.view(data.size(0), -1))
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Rank {rank}, Epoch {epoch}, Loss: {running_loss/len(loader)}')

    # 清理分布式环境
    dist.destroy_process_group()

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', type=int, default=4, help='number of nodes for distributed training')
    parser.add_argument('--rank', type=int, default=0, help='node rank')
    parser.add_argument('--master_ip', type=str, default='127.0.0.1', help='master node IP address')
    parser.add_argument('--master_port', type=str, default='12345', help='master node port')
    args = parser.parse_args()

    main(args.rank, args.world_size)

在运行上述代码之前,你需要确保所有节点上都安装了PyTorch,并且网络设置正确。然后,你可以使用torch.distributed.launch来启动分布式训练:

python -m torch.distributed.launch --nproc_per_node=<num_gpus_per_node> --nnodes=<num_nodes> --node_rank=<node_rank> --master_addr=<master_ip> --master_port=<master_port> your_training_script.py

其中<num_gpus_per_node>是每个节点上的GPU数量,<num_nodes>是总节点数,<node_rank>是当前节点的排名,<master_ip><master_port>是主节点的IP地址和端口号。

请注意,这只是一个基本的示例,实际的分布式训练可能会涉及更复杂的网络配置、数据加载和模型同步策略。此外,对于大规模的分布式训练,可能还需要考虑使用专门的调度系统,如Slurm或Kubernetes。

0
看了该问题的人还看了