ubuntu

Ubuntu系统中如何备份和恢复PyTorch模型

小樊
34
2025-12-20 18:04:23
栏目: 智能运维

在Ubuntu系统中备份和恢复PyTorch模型,可以按照以下步骤进行:

备份PyTorch模型

  1. 保存模型权重: 使用torch.save()函数将模型的权重保存到一个文件中。

    import torch
    import torchvision.models as models
    
    # 创建一个示例模型
    model = models.resnet18(pretrained=True)
    
    # 保存模型权重
    torch.save(model.state_dict(), 'model_weights.pth')
    
  2. 保存整个模型(可选): 如果你希望保存整个模型(包括架构和权重),可以使用torch.save()函数保存整个模型对象。

    # 保存整个模型
    torch.save(model, 'model.pth')
    

恢复PyTorch模型

  1. 加载模型权重: 使用torch.load()函数加载之前保存的模型权重,并将其加载到模型中。

    # 创建一个与之前相同的模型架构
    model = models.resnet18(pretrained=False)
    
    # 加载模型权重
    model.load_state_dict(torch.load('model_weights.pth'))
    
  2. 加载整个模型(可选): 如果你之前保存了整个模型,可以直接加载整个模型对象。

    # 加载整个模型
    model = torch.load('model.pth')
    

注意事项

示例代码总结

import torch
import torchvision.models as models

# 创建一个示例模型
model = models.resnet18(pretrained=True)

# 保存模型权重
torch.save(model.state_dict(), 'model_weights.pth')

# 加载模型权重
model = models.resnet18(pretrained=False)
model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))

通过以上步骤,你可以在Ubuntu系统中轻松备份和恢复PyTorch模型。

0
看了该问题的人还看了