在Linux上,使用PyTorch保存和加载模型的过程相对简单。以下是详细的步骤:
torch.save()
函数来保存整个模型或模型的状态字典。import torch
import torch.nn as nn
# 假设你有一个模型类
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
# 创建模型实例
model = MyModel()
# 训练模型...
# 保存整个模型
torch.save(model, 'model.pth')
# 或者只保存模型的状态字典
torch.save(model.state_dict(), 'model_state_dict.pth')
torch.load()
函数来加载模型或模型的状态字典。# 加载整个模型
model = torch.load('model.pth')
# 或者加载模型的状态字典
model = MyModel() # 创建一个新的模型实例
model.load_state_dict(torch.load('model_state_dict.pth'))
设备兼容性:如果你在GPU上训练模型,保存的模型会包含GPU相关的信息。在加载模型时,需要确保模型在相同的设备上(CPU或GPU)。
# 如果模型是在GPU上训练的,加载到CPU上
model = torch.load('model.pth', map_location=torch.device('cpu'))
# 如果模型是在CPU上训练的,加载到GPU上
model = torch.load('model.pth', map_location=torch.device('cuda'))
版本兼容性:确保保存和加载模型的PyTorch版本兼容。不同版本的PyTorch可能会有不同的模型格式。
自定义对象:如果你使用了自定义的损失函数、优化器或其他自定义对象,需要在加载模型时提供这些对象的定义。
# 假设你有一个自定义的损失函数
class CustomLoss(nn.Module):
def __init__(self):
super(CustomLoss, self).__init__()
self.loss_fn = nn.MSELoss()
def forward(self, outputs, targets):
return self.loss_fn(outputs, targets)
# 加载模型时提供自定义对象
model = torch.load('model.pth', map_location=torch.device('cpu'), object_hook=lambda d: {k: v for k, v in d.items() if not k.startswith('__')})
model.loss_fn = CustomLoss()
通过以上步骤,你可以在Linux上轻松地保存和加载PyTorch模型。