Linux下PyTorch内存不足的解决方法
Batch Size是影响GPU显存占用的核心因素之一。减小训练或推理时的batch size(如从256降至128),可直接降低单次迭代中加载到GPU的数据量,从而减少显存使用。但需注意,过小的batch size可能导致模型收敛速度变慢或训练稳定性下降,需通过实验找到平衡点。
若无法进一步减小batch size(如受限于模型收敛需求),可通过梯度累积模拟大批次训练。具体做法是在多个小batch上计算梯度但不立即更新模型参数,待累积到指定步数(如4步)后再执行一次参数更新。例如:
accum_steps = 4
for i, (inputs, labels) in enumerate(dataloader):
outputs = model(inputs)
loss = criterion(outputs, labels) / accum_steps # 平均损失
loss.backward() # 累积梯度
if (i + 1) % accum_steps == 0:
optimizer.step() # 更新参数
optimizer.zero_grad() # 清空梯度
此方法可将显存需求降低至原来的1/accum_steps,适用于大batch训练场景。
PyTorch会缓存计算结果以加速后续操作,但长期运行可能导致缓存占用过多显存。可通过torch.cuda.empty_cache()手动释放未使用的缓存,尤其适用于迭代训练中不再需要的中间张量。例如,在每个epoch结束后调用该函数,清理缓存以释放空间。
混合精度训练结合了单精度(float32)和半精度(float16)计算,在保持模型精度的前提下,将参数、梯度和激活值的存储从float32转为float16,从而减少显存占用(通常可降低50%左右)。PyTorch通过torch.cuda.amp模块支持AMP,示例代码:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler() # 用于梯度缩放,防止数值溢出
for inputs, labels in dataloader:
optimizer.zero_grad()
with autocast(): # 自动选择float16/float32
outputs = model(inputs)
loss = criterion(outputs, labels)
scaler.scale(loss).backward() # 缩放梯度以避免溢出
scaler.step(optimizer) # 更新参数
scaler.update() # 调整缩放因子
该方法在不损失模型性能的情况下,显著提升训练速度并减少显存使用。
数据加载过程中的预处理(如图像缩放、归一化)或并行不足可能导致CPU成为瓶颈,进而引发GPU等待数据而浪费显存。优化方法包括:
torch.utils.data.DataLoader的num_workers参数增加数据加载的并行性(如设置为4或8,根据CPU核心数调整);pin_memory=True,将数据预加载到固定内存(Pinned Memory),加速数据从CPU到GPU的传输;对于超大规模模型(如LLaMA、GPT-3),即使使用上述方法仍可能因中间激活值过多导致显存不足。梯度检查点通过在前向传播时不保存所有中间激活值,而是在反向传播时重新计算它们,以“计算时间换显存空间”。PyTorch通过torch.utils.checkpoint模块支持该功能,示例:
from torch.utils.checkpoint import checkpoint
def forward_with_checkpoint(segment):
return segment(inputs) # segment为模型的一部分
outputs = checkpoint(forward_with_checkpoint, model.segment1) # 仅保存输入,不保存中间激活
此方法可将显存占用减少至原来的1/3~1/10,但会增加一定的计算时间。
选择参数量少、内存占用低的模型架构是根本解决之道。例如:
若单个GPU显存不足以容纳模型或batch,可使用多GPU或多机分布式训练,将模型或数据分散到多个设备上。PyTorch推荐使用DistributedDataParallel(DDP)替代DataParallel(DP),因为DDP通过多进程通信减少了主GPU的负担,提升了训练效率。示例代码:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
dist.init_process_group(backend='nccl') # 初始化分布式环境
model = DDP(model.to(device)) # 将模型包装为DDP模型
DDP可将显存占用分散到多个GPU,支持更大规模的模型训练。
实时监控GPU显存使用情况,有助于快速定位内存瓶颈。常用工具包括:
nvidia-smi:命令行工具,显示GPU利用率、显存占用等实时信息(如watch -n 1 nvidia-smi每秒刷新一次);torch.profiler.profile记录显存分配事件,分析内存使用细节(如with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], profile_memory=True) as prof:);Linux系统中,系统缓存(如Page Cache)可能占用大量内存,导致PyTorch无法分配足够显存。可通过以下命令清理缓存(需root权限):
sudo echo 3 | sudo tee /proc/sys/vm/drop_caches # 清理Page Cache、dentries和inodes
此外,关闭不必要的应用程序(如浏览器、大型软件),释放系统内存,为PyTorch提供更多可用资源。