ubuntu

Ubuntu PyTorch如何进行模型保存与加载

小樊
35
2025-08-05 18:46:06
栏目: 智能运维

在Ubuntu系统中使用PyTorch进行模型的保存与加载是一个常见的操作。以下是详细的步骤和示例代码:

保存模型

  1. 定义模型:首先,你需要定义一个模型。
  2. 训练模型:在训练过程中,模型会不断更新其权重。
  3. 保存模型:使用torch.save()函数将整个模型或仅保存模型的状态字典(state_dict)。

示例代码

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的神经网络模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(784, 10)  # 假设输入是784维,输出是10类

    def forward(self, x):
        x = x.view(x.size(0), -1)  # 将输入展平
        x = self.fc(x)
        return x

# 创建模型实例
model = SimpleNet()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 假设我们有一些训练数据
inputs = torch.randn(64, 1, 28, 28)  # 示例输入
labels = torch.randint(0, 10, (64,))  # 示例标签

# 训练模型(这里省略了训练循环)
for epoch in range(5):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

# 保存整个模型
torch.save(model, 'model.pth')

# 或者仅保存模型的状态字典
torch.save(model.state_dict(), 'model_state_dict.pth')

加载模型

  1. 加载模型:使用torch.load()函数加载模型或模型的状态字典。
  2. 加载状态字典:如果之前保存的是状态字典,需要先创建一个模型实例,然后加载状态字典。

示例代码

# 加载整个模型
loaded_model = torch.load('model.pth')

# 或者加载模型的状态字典
model = SimpleNet()  # 创建一个新的模型实例
model.load_state_dict(torch.load('model_state_dict.pth'))

# 确保模型在评估模式
model.eval()

# 使用加载的模型进行预测
with torch.no_grad():
    test_inputs = torch.randn(1, 1, 28, 28)  # 示例测试输入
    predictions = loaded_model(test_inputs)
    print(predictions)

注意事项

  1. 设备兼容性:如果模型是在GPU上训练的,保存的模型文件默认也会包含GPU信息。在加载到CPU上进行推理时,需要将模型移动到CPU:
    model = model.to('cpu')
    
  2. 版本兼容性:不同版本的PyTorch可能会有不同的模型保存格式,确保加载模型的PyTorch版本与保存模型的版本兼容。

通过以上步骤,你可以在Ubuntu系统中轻松地进行PyTorch模型的保存与加载。

0
看了该问题的人还看了