linux

Linux PyTorch如何进行模型保存与加载

小樊
52
2025-08-07 13:16:42
栏目: 智能运维

在Linux环境下使用PyTorch进行模型的保存与加载是一个常见的操作。以下是详细的步骤和示例代码:

保存模型

  1. 保存整个模型: 使用torch.save()函数可以保存整个模型,包括模型的结构、参数和优化器状态。

    import torch
    import torch.nn as nn
    import torch.optim as optim
    
    # 定义一个简单的模型
    class SimpleModel(nn.Module):
        def __init__(self):
            super(SimpleModel, self).__init__()
            self.fc = nn.Linear(10, 5)
    
        def forward(self, x):
            return self.fc(x)
    
    model = SimpleModel()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    
    # 保存模型
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'model.pth')
    
  2. 只保存模型参数: 如果只需要保存模型的参数,可以使用model.state_dict()

    torch.save(model.state_dict(), 'model_params.pth')
    

加载模型

  1. 加载整个模型: 使用torch.load()函数加载模型,并使用load_state_dict()方法将参数加载到模型中。

    # 加载模型
    model = SimpleModel()
    model.load_state_dict(torch.load('model.pth'))
    model.eval()  # 设置模型为评估模式
    
  2. 只加载模型参数: 如果之前只保存了模型参数,可以直接加载到模型中。

    model = SimpleModel()
    model.load_state_dict(torch.load('model_params.pth'))
    model.eval()  # 设置模型为评估模式
    

注意事项

通过以上步骤,你可以在Linux环境下轻松地进行PyTorch模型的保存与加载。

0
看了该问题的人还看了