linux

Linux上PyTorch模型如何保存与加载

小樊
47
2025-06-21 10:42:48
栏目: 智能运维

在Linux上,使用PyTorch保存和加载模型的过程相对简单。以下是详细的步骤:

保存模型

  1. 训练模型:首先,你需要训练你的模型。
  2. 保存模型:使用torch.save()函数来保存整个模型或模型的状态字典。
import torch
import torch.nn as nn

# 假设你有一个模型类
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)

# 创建模型实例
model = MyModel()

# 训练模型...

# 保存整个模型
torch.save(model, 'model.pth')

# 或者只保存模型的状态字典
torch.save(model.state_dict(), 'model_state_dict.pth')

加载模型

  1. 加载模型:使用torch.load()函数来加载模型或模型的状态字典。
# 加载整个模型
model = torch.load('model.pth')

# 或者加载模型的状态字典
model = MyModel()  # 创建一个新的模型实例
model.load_state_dict(torch.load('model_state_dict.pth'))

注意事项

  1. 设备兼容性:如果你在GPU上训练模型,保存的模型会包含GPU相关的信息。在加载模型时,需要确保模型在相同的设备上(CPU或GPU)。

    # 如果模型是在GPU上训练的,加载到CPU上
    model = torch.load('model.pth', map_location=torch.device('cpu'))
    
    # 如果模型是在CPU上训练的,加载到GPU上
    model = torch.load('model.pth', map_location=torch.device('cuda'))
    
  2. 版本兼容性:确保保存和加载模型的PyTorch版本兼容。不同版本的PyTorch可能会有不同的模型格式。

  3. 自定义对象:如果你使用了自定义的损失函数、优化器或其他自定义对象,需要在加载模型时提供这些对象的定义。

    # 假设你有一个自定义的损失函数
    class CustomLoss(nn.Module):
        def __init__(self):
            super(CustomLoss, self).__init__()
            self.loss_fn = nn.MSELoss()
    
        def forward(self, outputs, targets):
            return self.loss_fn(outputs, targets)
    
    # 加载模型时提供自定义对象
    model = torch.load('model.pth', map_location=torch.device('cpu'), object_hook=lambda d: {k: v for k, v in d.items() if not k.startswith('__')})
    model.loss_fn = CustomLoss()
    

通过以上步骤,你可以在Linux上轻松地保存和加载PyTorch模型。

0
看了该问题的人还看了