linux

如何在Linux中使用PyTorch进行分布式训练

小樊
45
2025-12-08 15:05:47
栏目: 智能运维

在Linux中使用PyTorch进行分布式训练,可以遵循以下步骤:

1. 环境准备

2. 配置分布式环境

3. 编写分布式训练代码

使用PyTorch的torch.distributed模块来编写分布式训练代码。以下是一个简单的示例:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms

def main(rank, world_size):
    # 初始化分布式环境
    dist.init_process_group(backend='nccl', init_method=f'tcp://{master_addr}:{master_port}', world_size=world_size, rank=rank)

    # 创建模型并移动到GPU
    model = nn.Linear(10, 10).to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # 创建数据加载器
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)

    # 创建优化器
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)

    # 训练模型
    for epoch in range(5):
        sampler.set_epoch(epoch)
        for data, target in dataloader:
            data, target = data.to(rank), target.to(rank)
            optimizer.zero_grad()
            output = ddp_model(data)
            loss = nn.functional.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
        print(f'Epoch {epoch}, Loss: {loss.item()}')

    # 清理分布式环境
    dist.destroy_process_group()

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', type=int, default=2)
    parser.add_argument('--rank', type=int, default=0)
    args = parser.parse_args()

    main(args.rank, args.world_size)

4. 运行分布式训练

使用torch.distributed.launchmpirun来启动分布式训练。

使用torch.distributed.launch

python -m torch.distributed.launch --nproc_per_node=<number_of_gpus> --nnodes=<number_of_nodes> --node_rank=<node_rank> --master_addr=<master_ip> --master_port=<master_port> your_script.py

使用mpirun

mpirun -np <total_number_of_gpus> -host <master_ip> python your_script.py --world_size <number_of_gpus> --rank <rank>

5. 注意事项

通过以上步骤,你可以在Linux环境中使用PyTorch进行分布式训练。

0
看了该问题的人还看了