在Linux上使用PyTorch进行GPU加速,需按以下步骤操作:
nvidia-smi
命令检查驱动是否正常。PATH
和LD_LIBRARY_PATH
)。pip install torch --extra-index-url https://download.pytorch.org/whl/cu117
),安装时需指定cudatoolkit
版本。import torch; print(torch.cuda.is_available())
,若输出True
则配置成功。device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
input_tensor = input_tensor.to(device)
注意:
torch.nn.DataParallel
或DistributedDataParallel
。