linux

Linux中PyTorch的模型保存与加载方法

小樊
42
2025-08-05 13:00:49
栏目: 智能运维

在Linux系统中,使用PyTorch保存和加载模型的步骤如下:

保存模型

  1. 定义模型: 首先,你需要定义你的PyTorch模型。

    import torch
    import torch.nn as nn
    
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            self.fc = nn.Linear(10, 5)
    
        def forward(self, x):
            return self.fc(x)
    
    model = MyModel()
    
  2. 训练模型: 在训练过程中,你可以保存模型的权重。

    # 假设你已经训练了模型
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for epoch in range(10):
        # 训练代码...
        pass
    
    # 保存模型权重
    torch.save(model.state_dict(), 'model.pth')
    

加载模型

  1. 加载模型权重: 当你需要重新加载模型并进行推理或继续训练时,可以使用load_state_dict方法。

    # 创建相同的模型实例
    model = MyModel()
    
    # 加载模型权重
    model.load_state_dict(torch.load('model.pth'))
    
    # 如果模型在GPU上训练,需要将其移动到CPU
    model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
    
  2. 使用模型进行推理: 加载模型后,你可以使用它进行推理。

    # 假设输入数据
    input_data = torch.randn(1, 10)
    
    # 使用模型进行推理
    output = model(input_data)
    print(output)
    

注意事项

通过以上步骤,你可以在Linux系统中轻松地保存和加载PyTorch模型。

0
看了该问题的人还看了