以下是在Linux上优化PyTorch模型的常用方法,涵盖硬件、软件、代码及系统层面:
torch.nn.DataParallel或DistributedDataParallel实现多卡并行训练。量化:
# 静态量化示例
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
prepared_model = torch.quantization.prepare(model)
# 校准数据后转换
quantized_model = torch.quantization.convert(prepared_model)
剪枝:
torch.nn.utils.prune模块。# 结构化剪枝示例(通道剪枝)
prune.ln_structured(model.conv1, name='weight', amount=0.2, dim=0, n=2)
模型蒸馏:用大模型(教师模型)指导小模型(学生模型)训练,压缩模型规模。
混合精度训练:使用torch.cuda.amp混合FP16/FP32计算,减少显存占用并加速训练。
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
output = model(input)
loss = criterion(output, label)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
数据加载优化:
DataLoader的num_workers参数并行加载数据,配合prefetch_factor预取数据。梯度累积:通过多次小批量计算梯度后统一更新,模拟更大batch size。
sysctl或/etc/sysctl.conf优化文件描述符限制、网络参数等。nvidia-smi:监控GPU使用率、显存占用。torch.autograd.profiler:分析模型计算图瓶颈。Nsight Systems:定位CPU/GPU性能瓶颈。参考来源:[1,2,3,4,5,7,8,9,10,11]