linux

PyTorch在Linux上的多GPU支持

小樊
31
2025-07-01 14:14:36
栏目: 智能运维

PyTorch在Linux上支持多GPU主要通过以下几种方式实现:

数据并行

数据并行是将模型和数据分布在多个GPU上进行训练。这可以通过PyTorch的torch.nn.DataParallel类来实现。使用DataParallel可以轻松地在多个GPU上并行化模型。例如:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型
model = nn.Linear(10, 1).to('cuda')

# 使用DataParallel包装模型
if torch.cuda.device_count() > 1:
    print(f"Let's use {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练循环
for data, target in dataloader:
    data, target = data.to('cuda'), target.to('cuda')
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

模型并行

模型并行是将模型的不同部分放在不同的GPU上进行训练。这通常用于模型太大而无法放入单个GPU内存的情况。PyTorch提供了torch.nn.parallel.DistributedDataParallel类来实现分布式训练,它提供了更强大的分布式训练支持。

环境配置

为了在Linux系统上使用PyTorch的多GPU功能,需要确保系统配置正确。这包括安装NVIDIA显卡驱动、CUDA Toolkit、cuDNN以及正确设置环境变量。例如,可以使用以下命令安装推荐的NVIDIA驱动:

sudo ubuntu-drivers autoinstall

以及安装PyTorch:

conda install pytorch torchvision torchaudio cudatoolkit -c pytorch -c conda-forge

验证多GPU支持

安装完成后,可以通过以下代码验证PyTorch是否支持多GPU:

import torch
print(torch.cuda.device_count())  # 输出GPU数量
print(torch.cuda.is_available())  # 输出True,表示CUDA可用

请注意,具体的安装步骤和命令可能会随着PyTorch和CUDA版本的更新而发生变化,因此在安装时请参考最新的官方文档。

0
看了该问题的人还看了