
How should PyTorch memory management be optimized on Ubuntu?

小樊
2025-11-25 02:18:51

A Practical Guide to PyTorch Memory Optimization on Ubuntu

1. Basic GPU memory optimization
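
As a minimal sketch of the most basic lever (the tensor below is only an illustration): drop Python references to large GPU tensors you no longer need, then let PyTorch's caching allocator hand the memory back to the driver.

import torch

buf = torch.empty(4096, 4096, device='cuda')   # illustrative ~64 MB float32 buffer

del buf                                 # drop the last Python reference
torch.cuda.empty_cache()                # return cached, now-unused blocks to the driver
torch.cuda.reset_peak_memory_stats()    # optional: restart peak-memory tracking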

2. Advanced GPU memory-saving techniques
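
One advanced technique worth knowing alongside the template in section 5 is gradient (activation) checkpointing: activations are recomputed during the backward pass instead of being kept alive through the whole forward pass. A minimal sketch, with the model depth and segment count chosen purely for illustration:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Illustrative deep stack; substitute the model you are actually training
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(16)]).cuda()
x = torch.randn(32, 1024, device='cuda', requires_grad=True)

# Keep activations only at 4 segment boundaries; recompute the rest during backward
out = checkpoint_sequential(model, 4, x)
out.sum().backward()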

3. Monitoring and diagnostic tools
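
For monitoring, nvidia-smi gives the system-level view, while PyTorch's own counters (also used by the log_mem helper in section 5) distinguish memory held by live tensors from memory merely reserved by the caching allocator. A brief sketch:

# Shell: per-process GPU memory, refreshed every second
#   watch -n 1 nvidia-smi

import torch

print(f"allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")       # live tensors
print(f"reserved:  {torch.cuda.memory_reserved() / 1024**2:.1f} MB")        # caching allocator total
print(f"peak:      {torch.cuda.max_memory_allocated() / 1024**2:.1f} MB")   # high-water mark
print(torch.cuda.memory_summary())    # detailed per-pool breakdown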

4. OOM emergency response and troubleshooting
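
When a CUDA out-of-memory error does occur, a common emergency pattern is to catch it, free the allocator's cached blocks, and retry with a smaller batch; combined with the PYTORCH_CUDA_ALLOC_CONF setting shown in section 5, this resolves most fragmentation-related OOMs. The train_step function and the batch-halving below are illustrative placeholders:

import torch

def run_with_fallback(train_step, batch, min_batch=1):
    """Retry train_step after a CUDA OOM, halving the batch each time."""
    while True:
        try:
            return train_step(batch)
        except RuntimeError as e:
            if "out of memory" not in str(e) or len(batch) <= min_batch:
                raise                           # not an OOM, or nothing left to shrink
            torch.cuda.empty_cache()            # release cached blocks before retrying
            batch = batch[: len(batch) // 2]    # halve the batch and try again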

5. A minimal, ready-to-use optimization template

# Training: mixed precision (AMP) + gradient accumulation
import torch, torch.nn as nn, torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

device = torch.device('cuda')
model = nn.Linear(1024, 1024).to(device)
criterion = nn.MSELoss()             # placeholder loss; replace with your task's criterion
optimizer = optim.Adam(model.parameters())
scaler = GradScaler()

accumulation_steps = 4               # effective batch = batch_size * accumulation_steps
for i, (x, y) in enumerate(dataloader):   # dataloader: your existing DataLoader
    x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
    with autocast():                 # FP16/BF16 forward pass cuts activation memory
        loss = criterion(model(x), y) / accumulation_steps
    scaler.scale(loss).backward()

    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)  # set_to_none frees gradient memory instead of zeroing it

# Optional: clean up at the end of a stage
# del x, y, loss
# torch.cuda.empty_cache()

# Inference: no_grad avoids storing activations for backward
model.eval()
with torch.no_grad():
    for x in dataloader:
        x = x.to(device, non_blocking=True)
        out = model(x)

# Optional: clean up
# del x, out
# torch.cuda.empty_cache()
def log_mem(msg=""):
    a = torch.cuda.memory_allocated() / 1024**2
    r = torch.cuda.memory_reserved() / 1024**2
    print(f"{msg} Allocated={a:.1f}MB Reserved={r:.1f}MB")

# DataLoader: pinned host memory lets .to(device, non_blocking=True) overlap copies with compute
from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=bs, num_workers=4,   # dataset/bs: your dataset and batch size
                    pin_memory=True, prefetch_factor=2)

# Allocator tuning (run in the shell before launching Python, to reduce fragmentation):
export PYTORCH_CUDA_ALLOC_CONF="garbage_collection_threshold:0.8,max_split_size_mb:128"

The template above covers the key points (mixed precision, gradient accumulation, no_grad, cache cleanup, and memory monitoring) and can be dropped directly into an existing training script, with the parameters fine-tuned to the GPU memory you have available.
