linux

Linux下PyTorch的分布式训练如何实现

小樊
41
2025-09-06 19:36:17
栏目: 智能运维

在Linux下使用PyTorch进行分布式训练,主要涉及以下几个步骤:

  1. 环境准备

    • 确保所有节点上都安装了相同版本的PyTorch和依赖库。
    • 配置SSH无密码登录,以便各个节点之间可以互相通信。
  2. 初始化进程组

    • 使用torch.distributed.init_process_group()函数来初始化分布式环境。这个函数需要几个参数,包括后端(如ncclgloo等)、初始化方法(如tcp://)、IP地址和端口号。
  3. 数据并行

    • 使用torch.nn.parallel.DistributedDataParallel(DDP)来包装你的模型。DDP会自动处理数据的分片和梯度的聚合。
  4. 数据加载

    • 使用torch.utils.data.distributed.DistributedSampler来确保每个进程处理数据集的不同部分。
  5. 训练循环

    • 在训练循环中,每个进程都会执行前向传播、计算损失、反向传播和参数更新。
  6. 保存模型

    • 如果需要保存模型,确保只在主进程中执行保存操作,以避免冲突。

下面是一个简单的示例代码,展示了如何在Linux下使用PyTorch进行分布式训练:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 初始化分布式环境
world_size = 4  # 假设有4个GPU
rank = 0  # 当前进程的rank
master_ip = '192.168.1.1'  # 主节点的IP地址
master_port = '12345'  # 主节点的端口号
torch.distributed.init_process_group(
    backend='nccl',
    init_method=f'tcp://{master_ip}:{master_port}',
    world_size=world_size,
    rank=rank
)

# 定义模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])

# 数据加载
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)

# 优化器
optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)

# 训练循环
for epoch in range(5):
    sampler.set_epoch(epoch)
    running_loss = 0.0
    for data, target in dataloader:
        data, target = data.to(rank), target.to(rank)
        optimizer.zero_grad()
        output = ddp_model(data)
        loss = nn.functional.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}')

# 保存模型(只在主进程中执行)
if rank == 0:
    torch.save(ddp_model.state_dict(), 'model.pth')

# 清理分布式环境
torch.distributed.destroy_process_group()

注意事项:

通过以上步骤,你可以在Linux环境下使用PyTorch实现分布式训练。

0
看了该问题的人还看了