在 Linux 系统中,使用 PyTorch 保存模型的方法如下:
首先,确保已经安装了 PyTorch。如果没有,请访问 PyTorch 官网(https://pytorch.org/)并按照说明进行安装。
在你的 Python 脚本或 Jupyter Notebook 中,导入 PyTorch 库:
import torch
class SimpleModel(torch.nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = torch.nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
model = SimpleModel()
训练模型(这里省略训练过程,仅作为示例)。
使用 torch.save()
函数保存模型。这将把整个模型(包括结构和参数)保存到一个文件中。例如,将模型保存到名为 model.pth
的文件中:
torch.save(model, 'model.pth')
现在,模型已经被保存到 model.pth
文件中。要加载模型,可以使用 torch.load()
函数:
loaded_model = torch.load('model.pth')
注意:在加载模型时,确保你的环境中已经定义了与保存模型时相同的模型结构。否则,可能会出现错误。