在Ubuntu系统中备份和恢复PyTorch模型,可以按照以下步骤进行:
保存模型权重:
使用torch.save()函数将模型的权重保存到一个文件中。
import torch
import torchvision.models as models
# 创建一个示例模型
model = models.resnet18(pretrained=True)
# 保存模型权重
torch.save(model.state_dict(), 'model_weights.pth')
保存整个模型(可选):
如果你希望保存整个模型(包括架构和权重),可以使用torch.save()函数保存整个模型对象。
# 保存整个模型
torch.save(model, 'model.pth')
加载模型权重:
使用torch.load()函数加载之前保存的模型权重,并将其加载到模型中。
# 创建一个与之前相同的模型架构
model = models.resnet18(pretrained=False)
# 加载模型权重
model.load_state_dict(torch.load('model_weights.pth'))
加载整个模型(可选): 如果你之前保存了整个模型,可以直接加载整个模型对象。
# 加载整个模型
model = torch.load('model.pth')
设备一致性:在加载模型权重时,确保模型和权重在同一设备上(CPU或GPU)。如果模型在GPU上训练,但在CPU上加载,需要将权重移动到CPU。
# 如果模型在GPU上训练,但在CPU上加载
model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))
模型架构一致性:确保加载权重的模型架构与保存权重的模型架构一致。如果不一致,可能会导致加载失败或模型行为异常。
import torch
import torchvision.models as models
# 创建一个示例模型
model = models.resnet18(pretrained=True)
# 保存模型权重
torch.save(model.state_dict(), 'model_weights.pth')
# 加载模型权重
model = models.resnet18(pretrained=False)
model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))
通过以上步骤,你可以在Ubuntu系统中轻松备份和恢复PyTorch模型。