在Linux系统中,使用PyTorch保存和加载模型的步骤如下:
定义模型: 首先,你需要定义你的PyTorch模型。
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
model = MyModel()
训练模型: 在训练过程中,你可以保存模型的权重。
# 假设你已经训练了模型
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(10):
# 训练代码...
pass
# 保存模型权重
torch.save(model.state_dict(), 'model.pth')
加载模型权重:
当你需要重新加载模型并进行推理或继续训练时,可以使用load_state_dict
方法。
# 创建相同的模型实例
model = MyModel()
# 加载模型权重
model.load_state_dict(torch.load('model.pth'))
# 如果模型在GPU上训练,需要将其移动到CPU
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
使用模型进行推理: 加载模型后,你可以使用它进行推理。
# 假设输入数据
input_data = torch.randn(1, 10)
# 使用模型进行推理
output = model(input_data)
print(output)
map_location=torch.device('cpu')
参数。通过以上步骤,你可以在Linux系统中轻松地保存和加载PyTorch模型。