ubuntu

Ubuntu系统中PyTorch内存不足怎么解决

小樊
43
2025-12-22 16:36:33
栏目: 智能运维

Ubuntu下PyTorch内存不足的排查与优化

一 快速定位问题

二 训练阶段的高效优化

三 推理阶段与常见OOM场景

四 系统与CUDA层面的调优

五 实用代码片段

import torch
from torch.cuda.amp import autocast, GradScaler

model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

for data, target in dataloader:
    optimizer.zero_grad()
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
def find_max_batch(model, input_shape, max_mem=8*1024**3):
    batch_size = 1
    while True:
        try:
            with torch.cuda.amp.autocast(enabled=True):
                _ = model(torch.randn(*input_shape, device='cuda')[:batch_size])
            used = torch.cuda.max_memory_allocated()
            if used > 0.9 * max_mem:
                return max(1, batch_size - 1)
            batch_size *= 2
        except RuntimeError:
            return max(1, batch_size // 2)
import torch, gc

def clear_cache():
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # 等待所有流完成
        torch.cuda.empty_cache()
        gc.collect()

0
看了该问题的人还看了