ubuntu

Ubuntu上PyTorch内存不足怎么解决

小樊
36
2025-12-10 20:31:11
栏目: 智能运维

Ubuntu上PyTorch内存不足的排查与优化

一 快速定位问题

二 通用优化清单(GPU显存)

三 系统内存不足的处理

四 推理与超大模型场景的实用建议

五 最小可用代码示例 AMP + 梯度累积 + 清理缓存

import torch, gc
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader

device = torch.device('cuda')
model = MyModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

accum_steps = 4  # 累积4个小批量的梯度
loader = DataLoader(dataset, batch_size=16, num_workers=4, pin_memory=True)

for i, (x, y) in enumerate(loader, 1):
    x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
    optimizer.zero_grad(set_to_none=True)  # 更省显存的梯度清零方式

    with autocast():
        out = model(x)
        loss = criterion(out, y) / accum_steps  # 注意缩放

    scaler.scale(loss).backward()

    if i % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    # 可选:释放本轮不再使用的临时张量与缓存
    del x, y, out, loss
    if i % 50 == 0:
        torch.cuda.empty_cache()
        gc.collect()

0
看了该问题的人还看了