在Ubuntu系统中使用PyTorch进行模型的保存与加载是一个常见的操作。以下是详细的步骤和示例代码:
torch.save()
函数将整个模型或仅保存模型的状态字典(state_dict)。import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的神经网络模型
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = nn.Linear(784, 10) # 假设输入是784维,输出是10类
def forward(self, x):
x = x.view(x.size(0), -1) # 将输入展平
x = self.fc(x)
return x
# 创建模型实例
model = SimpleNet()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 假设我们有一些训练数据
inputs = torch.randn(64, 1, 28, 28) # 示例输入
labels = torch.randint(0, 10, (64,)) # 示例标签
# 训练模型(这里省略了训练循环)
for epoch in range(5):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 保存整个模型
torch.save(model, 'model.pth')
# 或者仅保存模型的状态字典
torch.save(model.state_dict(), 'model_state_dict.pth')
torch.load()
函数加载模型或模型的状态字典。# 加载整个模型
loaded_model = torch.load('model.pth')
# 或者加载模型的状态字典
model = SimpleNet() # 创建一个新的模型实例
model.load_state_dict(torch.load('model_state_dict.pth'))
# 确保模型在评估模式
model.eval()
# 使用加载的模型进行预测
with torch.no_grad():
test_inputs = torch.randn(1, 1, 28, 28) # 示例测试输入
predictions = loaded_model(test_inputs)
print(predictions)
model = model.to('cpu')
通过以上步骤,你可以在Ubuntu系统中轻松地进行PyTorch模型的保存与加载。