在Linux环境下使用PyTorch进行模型的保存与加载是一个常见的操作。以下是详细的步骤和示例代码:
保存整个模型:
使用torch.save()
函数可以保存整个模型,包括模型的结构、参数和优化器状态。
import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 保存模型
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, 'model.pth')
只保存模型参数:
如果只需要保存模型的参数,可以使用model.state_dict()
。
torch.save(model.state_dict(), 'model_params.pth')
加载整个模型:
使用torch.load()
函数加载模型,并使用load_state_dict()
方法将参数加载到模型中。
# 加载模型
model = SimpleModel()
model.load_state_dict(torch.load('model.pth'))
model.eval() # 设置模型为评估模式
只加载模型参数: 如果之前只保存了模型参数,可以直接加载到模型中。
model = SimpleModel()
model.load_state_dict(torch.load('model_params.pth'))
model.eval() # 设置模型为评估模式
设备兼容性:如果模型是在GPU上训练的,而加载时在CPU上,需要将模型参数移动到CPU。
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
版本兼容性:确保保存和加载模型的PyTorch版本一致,否则可能会出现兼容性问题。
优化器状态:如果保存了优化器状态,可以在加载模型后继续训练。
optimizer.load_state_dict(torch.load('model.pth')['optimizer_state_dict'])
通过以上步骤,你可以在Linux环境下轻松地进行PyTorch模型的保存与加载。