在CentOS系统上,使用PyTorch保存和加载模型的步骤与其他操作系统相同。以下是保存和加载PyTorch模型的基本方法:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
model = MyModel()
# 假设模型已经训练完成
torch.save(model.state_dict(), 'model.pth')
model.state_dict()
是一个包含模型所有参数的字典。torch.save()
函数将这个字典保存到文件 model.pth
中。
# 创建相同结构的模型实例
model = MyModel()
# 加载权重
model.load_state_dict(torch.load('model.pth'))
# 如果模型是在GPU上训练的,需要将模型移动到CPU并设置为评估模式
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
model.eval()
map_location
参数用于指定加载模型权重时的设备位置。如果模型是在GPU上训练的,你需要将其加载到CPU上。model.eval()
将模型设置为评估模式,这在推理时是必要的。
通过以上步骤,你可以在CentOS系统上轻松地保存和加载PyTorch模型。