1. 降低批次大小(Batch Size)
批次大小是影响GPU显存占用的核心因素之一。较小的批次会减少单次前向/反向传播所需的内存,但需平衡其对训练速度(如梯度更新频率)和模型性能(如泛化能力)的影响。建议通过实验找到“内存占用可接受且训练效率无明显下降”的最优批次值。
2. 使用混合精度训练(Automatic Mixed Precision, AMP)
混合精度结合float16(半精度)和float32(单精度)计算,在保持模型数值稳定性的同时,将显存占用减少约50%。PyTorch通过torch.cuda.amp模块实现自动混合精度:
scaler = torch.cuda.amp.GradScaler() # 梯度缩放器(防止数值溢出)
for data, label in dataloader:
optimizer.zero_grad()
with torch.cuda.amp.autocast(): # 自动选择float16/float32
output = model(data)
loss = criterion(output, label)
scaler.scale(loss).backward() # 缩放梯度以避免溢出
scaler.step(optimizer) # 更新参数
scaler.update() # 调整缩放因子
该技术尤其适合NVIDIA Volta架构及以上的GPU(如T4、A100)。
3. 释放不必要的张量与缓存
del关键字删除不再需要的张量(如中间结果、临时变量),减少引用计数。torch.cuda.empty_cache()释放PyTorch缓存池中未使用的显存(注意:此操作会触发同步,可能短暂降低性能,建议在调试或非训练阶段使用)。import gc; gc.collect()手动启动Python垃圾回收机制,清理无引用的对象。4. 优化数据加载流程
数据加载是内存瓶颈的常见来源,可通过以下方式优化:
DataLoader的num_workers参数(如num_workers=4),利用多进程并行读取数据,避免CPU成为瓶颈。albumentations),或直接在GPU上进行预处理(如torchvision.transforms的GPU版本)。5. 使用梯度累积(Gradient Accumulation)
梯度累积允许在多个小批次上累积梯度,再进行一次参数更新,从而在不增加显存占用的情况下,模拟更大批次的效果(如累积步数=4相当于将批次大小扩大4倍)。示例代码:
accumulation_steps = 4
for i, (data, label) in enumerate(dataloader):
output = model(data)
loss = criterion(output, label)
loss = loss / accumulation_steps # 归一化损失(避免梯度爆炸)
loss.backward() # 累积梯度
if (i + 1) % accumulation_steps == 0:
optimizer.step() # 更新参数
optimizer.zero_grad() # 清零梯度
该方法适用于显存不足以支持大批次训练的场景。
6. 释放计算图引用
PyTorch的动态计算图会保留中间结果以支持自动微分。若未正确断开引用,会导致显存无法释放。常见解决方法:
detach():将不需要梯度的张量从计算图中分离(如output = model(data).detach())。torch.no_grad():在推理或验证阶段禁用梯度计算(如with torch.no_grad(): output = model(data)),彻底避免计算图生成。7. 优化模型结构
选择或设计内存高效的模型结构,减少参数数量和显存占用:
1x1卷积可实现通道降维),适合处理图像数据。8. 监控与调试内存使用
通过工具实时监控显存占用,快速定位内存泄漏或瓶颈:
torch.cuda.memory_summary()查看显存分配详情(如已用/预留显存、张量数量),或torch.cuda.memory_allocated()/torch.cuda.memory_reserved()获取实时显存值。profile_memory=True),记录每个操作的显存消耗(如with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], profile_memory=True) as prof: ...)。9. 系统级优化
sudo sync && echo 3 | sudo tee /proc/sys/vm/drop_caches命令清理Ubuntu系统的页面缓存(Page Cache),释放物理内存(不影响正在运行的进程)。sudo dd if=/dev/zero of=/swapfile bs=64M count=16 && sudo mkswap /swapfile && sudo swapon /swapfile),避免程序因内存耗尽而崩溃(注意:Swap速度远低于物理内存,仅作为临时解决方案)。torch.distributed模块),通过数据并行(Data Parallel)或模型并行(Model Parallel)减少单个设备的显存负载。