Ubuntu下PyTorch内存优化方法
批次大小是影响GPU显存使用的核心因素之一。较小的批次大小能直接减少显存占用,但需平衡其对训练速度(如梯度更新频率)和模型性能(如泛化能力)的影响。建议通过实验找到“显存占用可接受且不影响模型效果”的最小批次值。
半精度(float16)相比单精度(float32)可减少50%的显存占用,同时通过PyTorch的torch.cuda.amp模块实现自动混合精度(AMP),能在保持模型数值稳定性的前提下,自动在float16和float32之间切换(如梯度计算用float32保证稳定性,前向/反向传播用float16提升速度)。示例代码:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for inputs, targets in dataloader:
optimizer.zero_grad()
with autocast(): # 自动混合精度上下文
outputs = model(inputs)
loss = criterion(outputs, targets)
scaler.scale(loss).backward() # 缩放梯度避免underflow
scaler.step(optimizer) # 更新参数
scaler.update() # 调整缩放因子
del关键字删除不再需要的中间变量(如损失值、预测结果),断开其对显存的引用。torch.cuda.empty_cache()释放PyTorch缓存的无用显存(如未使用的中间结果),注意该操作不会释放被引用的张量。gc.collect()强制Python垃圾回收器回收无用对象,配合del和empty_cache()效果更佳。del outputs, loss # 删除无用变量
torch.cuda.empty_cache() # 清空GPU缓存
import gc
gc.collect() # 触发垃圾回收
梯度累积通过“多次小批次计算梯度→累加→一次更新”的方式,模拟更大批次的效果,同时不增加显存占用。适用于“显存不足但需较大批次”的场景。示例代码:
accumulation_steps = 4 # 累积4个小批次的梯度
for i, (inputs, targets) in enumerate(dataloader):
outputs = model(inputs)
loss = criterion(outputs, targets)
loss = loss / accumulation_steps # 归一化损失(避免梯度爆炸)
loss.backward() # 累积梯度
if (i + 1) % accumulation_steps == 0: # 达到累积步数时更新参数
optimizer.step()
optimizer.zero_grad() # 清零梯度
DataLoader的num_workers参数(设置为CPU核心数的2-4倍)启用多进程数据加载,避免数据预处理成为瓶颈。in_features与out_features转换为卷积核的in_channels与out_channels),减少参数数量(如ResNet-50的全连接层参数占比约90%)。将模型训练分布到多个GPU(单机多卡)或多台机器(多机多卡),通过数据并行(Data Parallelism)或模型并行(Model Parallelism)减少单个设备的显存负载。PyTorch提供torch.distributed模块支持分布式训练,常用 launch 工具如torchrun。示例代码(数据并行):
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
dist.init_process_group(backend='nccl') # 初始化进程组
model = DDP(model.to(device)) # 包装模型
torch.cuda.memory_summary()打印显存分配详情(如已用/剩余显存、缓存情况),或nvidia-smi命令实时监控GPU显存使用率。torch.utils.checkpoint检查张量是否意外保留计算图(如非训练场景未用with torch.no_grad()),或使用memory_profiler库逐行跟踪内存变化。sudo echo 3 | sudo tee /proc/sys/vm/drop_caches释放系统页面缓存(不影响正在运行的程序)。sudo dd if=/dev/zero of=/swapfile bs=64M count=16,sudo mkswap /swapfile,sudo swapon /swapfile),作为物理内存的扩展(注意:Swap性能低于物理内存,仅作临时解决方案)。