Distributed training with PyTorch on Linux mainly involves the following steps:
Install PyTorch: the official pip/conda packages already include torch.distributed, so a standard installation with CUDA support is sufficient.
Configure the environment: every node should run the same versions of Python, PyTorch, and CUDA/NCCL.
Network configuration: all nodes must be able to reach the master node's IP address and port (open the port in the firewall if necessary); a quick check is sketched below.
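Before launching a job, a minimal sanity-check sketch like the following (not part of the original steps, just an illustration) can verify on each node that the installed build supports distributed training:
import torch
import torch.distributed as dist
# Quick environment check to run on every node before starting a job.
print('PyTorch version:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())
print('Distributed support available:', dist.is_available())
print('NCCL backend available:', dist.is_nccl_available())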
Initialize the distributed environment with torch.distributed.init_process_group.
import torch
import torch.distributed as dist
dist.init_process_group(
backend='nccl', # 'nccl' for GPU, 'gloo' for CPU
init_method='tcp://<master_ip>:<master_port>', # e.g., 'tcp://192.168.1.1:23456'
world_size=<world_size>, # total number of processes
rank=<rank> # rank 0 is the master, others are workers
)
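When the script is started by a launcher such as torchrun, the master address, port, rank, and world size are passed through environment variables, so the initialization can be written more simply. A sketch assuming such a launcher is used:
import torch.distributed as dist
# MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are set by the launcher,
# so init_process_group can read everything from the environment.
dist.init_process_group(backend='nccl', init_method='env://')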
Wrap your model with torch.nn.parallel.DistributedDataParallel.
model = YourModel().to(rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
Use torch.utils.data.distributed.DistributedSampler to make sure each process works on a different subset of the data.
from torch.utils.data import DataLoader, DistributedSampler
dataset = YourDataset()
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=<batch_size>, sampler=sampler)
In the training loop, each process runs its own training steps. Call sampler.set_epoch(epoch) at the start of every epoch so the data is reshuffled differently across epochs.
for epoch in range(num_epochs):
sampler.set_epoch(epoch)
for data, target in dataloader:
data, target = data.to(rank), target.to(rank)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
DistributedDataParallel synchronizes gradients across processes automatically during loss.backward(), so no extra communication code is needed in the loop.
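If you accumulate gradients over several micro-batches, the all-reduce can be skipped on the intermediate backward passes with the model's no_sync() context manager. A hedged sketch of this pattern (accumulation_steps is an illustrative value, not part of the original example):
accumulation_steps = 4
optimizer.zero_grad()
for step, (data, target) in enumerate(dataloader):
    data, target = data.to(rank), target.to(rank)
    if (step + 1) % accumulation_steps != 0:
        with model.no_sync():              # accumulate locally, no all-reduce
            loss = criterion(model(data), target)
            loss.backward()
    else:
        loss = criterion(model(data), target)
        loss.backward()                    # gradients are all-reduced here
        optimizer.step()
        optimizer.zero_grad()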
When all processes have finished, save the model from a single process (typically rank 0) so that every process does not write the same file.
if rank == 0:
torch.save(model.state_dict(), 'model.pth')
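Note that a DDP-wrapped model stores its parameters under a module. prefix; saving the underlying module instead makes the checkpoint loadable by a plain, unwrapped model. A small sketch:
if rank == 0:
    # model.module is the original, unwrapped model, so the saved keys
    # have no 'module.' prefix and can be loaded without DDP.
    torch.save(model.module.state_dict(), 'model.pth')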
After training finishes, remember to clean up the distributed environment.
dist.destroy_process_group()
Below is a complete example showing how to run distributed training with PyTorch on Linux:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# Initialize the distributed environment.
# rank and world_size are normally provided by the launcher
# (e.g. torchrun sets the RANK and WORLD_SIZE environment variables).
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
dist.init_process_group(
    backend='nccl',
    init_method='tcp://192.168.1.1:23456',
    world_size=world_size,
    rank=rank
)
torch.cuda.set_device(rank)  # one GPU per process (single-node setup)
# Define the model: a minimal fully connected classifier for MNIST
class YourModel(nn.Module):
    def __init__(self):
        super(YourModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)      # flatten the 28x28 images
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
model = YourModel().to(rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
# Load the data; the DistributedSampler gives each process a different shard
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Training loop
for epoch in range(5):
sampler.set_epoch(epoch)
for data, target in dataloader:
data, target = data.to(rank), target.to(rank)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if rank == 0:
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
# Save the model (only on rank 0)
if rank == 0:
torch.save(model.state_dict(), 'model.pth')
# Clean up the distributed environment
dist.destroy_process_group()
Make sure the port given in init_method is not already in use by another process and that every node can reach the master address. With the steps above, you can run efficient distributed training with PyTorch on Linux.
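The example uses one process per GPU. You can start the processes with an external launcher such as torchrun, or spawn them from Python on a single node. The following is a minimal sketch that assumes the training code above has been wrapped in a function train(rank, world_size) (a hypothetical wrapper, not shown in the example); when launching this way, pass rank and world_size into init_process_group directly instead of reading RANK/WORLD_SIZE from the environment:
import torch
import torch.multiprocessing as mp

def main():
    world_size = torch.cuda.device_count()   # one process per local GPU
    # mp.spawn calls train(rank, world_size) once per process,
    # passing the process index as the first argument.
    # 'train' is the hypothetical wrapper around the training code above.
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

if __name__ == '__main__':
    main()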