在Linux下进行PyTorch的分布式训练,通常需要以下几个步骤:
环境准备:
启动分布式训练:
PyTorch提供了torch.distributed.launch
工具来简化分布式训练的启动过程。以下是一个基本的命令行示例:
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE --nnodes=NUM_NODES --node_rank=NODE_RANK --master_addr=MASTER_NODE_IP --master_port=12345 YOUR_TRAINING_SCRIPT.py
参数说明:
--nproc_per_node
:每个节点上使用的GPU数量。--nnodes
:总的节点数。--node_rank
:当前节点的排名(从0开始)。--master_addr
:主节点的IP地址。--master_port
:主节点上用于通信的端口号。YOUR_TRAINING_SCRIPT.py
:你的训练脚本。修改训练脚本: 在你的训练脚本中,需要初始化分布式环境。通常在脚本的最开始添加以下代码:
import torch.distributed as dist
dist.init_process_group(
backend='nccl', # 'nccl' is recommended for distributed GPU training
init_method='tcp://MASTER_NODE_IP:12345',
world_size=NUM_GPUS_YOU_HAVE * NUM_NODES,
rank=NODE_RANK
)
参数说明:
backend
:分布式后端,对于GPU训练推荐使用nccl
。init_method
:初始化分布式环境的地址。world_size
:总的进程数,等于GPU数量乘以节点数。rank
:当前进程的排名。数据并行:
在你的训练循环中,使用torch.nn.parallel.DistributedDataParallel
来包装你的模型:
model = YourModel().to(device)
ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
参数说明:
device_ids
:指定当前进程使用的GPU ID。运行训练: 在每个节点上运行修改后的训练脚本,确保所有节点都使用相同的命令行参数。
以下是一个完整的示例:
# 在主节点上运行
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=192.168.1.1 --master_port=12345 train.py
# 在其他节点上运行
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=192.168.1.1 --master_port=12345 train.py
在train.py
中:
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def main():
dist.init_process_group(
backend='nccl',
init_method='tcp://192.168.1.1:12345',
world_size=8,
rank=0 # 这个rank会在每个节点上变化
)
device = torch.device(f"cuda:{dist.get_rank()}")
model = YourModel().to(device)
ddp_model = DDP(model, device_ids=[dist.get_rank()])
# 训练循环
for data, target in dataloader:
data, target = data.to(device), target.to(device)
output = ddp_model(data)
loss = torch.nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()
if __name__ == "__main__":
main()
通过以上步骤,你可以在Linux下进行PyTorch的分布式训练。确保所有节点的网络配置正确,并且防火墙允许相应的端口通信。