centos

CentOS上PyTorch模型如何保存与加载

小樊
44
2025-04-15 00:35:44
栏目: 智能运维

在CentOS上,使用PyTorch保存和加载模型的过程与其他操作系统相同。以下是保存和加载PyTorch模型的基本步骤:

  1. 首先,确保已经安装了PyTorch。如果还没有安装,可以参考PyTorch官方网站上的安装指南:https://pytorch.org/get-started/locally/

  2. 创建一个简单的PyTorch模型。例如,我们可以创建一个简单的多层感知器(MLP):

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

input_size = 10
hidden_size = 5
output_size = 2

model = MLP(input_size, hidden_size, output_size)
  1. 训练模型。这里我们省略了训练过程,假设模型已经训练完成。

  2. 保存模型。使用torch.save()函数将模型保存到文件中:

torch.save(model.state_dict(), 'model.pth')

这将把模型的权重和偏置保存到名为model.pth的文件中。

  1. 加载模型。使用torch.load()函数从文件中加载模型,并使用model.load_state_dict()函数将权重和偏置加载到模型中:
loaded_model = MLP(input_size, hidden_size, output_size)
loaded_model.load_state_dict(torch.load('model.pth'))

现在,loaded_model变量包含了从文件中加载的模型,可以像使用原始模型一样使用它。

注意:在加载模型时,确保模型的结构与保存时的结构相同。否则,可能会出现错误。

0
看了该问题的人还看了