在Ubuntu上使用PyTorch进行模型的保存与加载,可以按照以下步骤操作:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
torch.save()
函数来保存整个模型或模型的状态字典。model = MyModel()
# 假设模型已经训练完成
torch.save(model, 'model.pth') # 保存整个模型
# 或者
torch.save(model.state_dict(), 'model_state_dict.pth') # 只保存模型的状态字典
torch.load()
函数来加载模型。如果你之前保存了整个模型,可以直接加载;如果只保存了状态字典,则需要先实例化模型,然后再加载状态字典。# 加载整个模型
model = torch.load('model.pth')
# 或者,如果只保存了状态字典
model = MyModel() # 先实例化模型
model.load_state_dict(torch.load('model_state_dict.pth'))
注意:在加载模型时,确保模型类(在本例中为MyModel
)已经在当前环境中定义。
# 假设我们有一些输入数据
input_data = torch.randn(1, 10)
# 使用模型进行推理
output = model(input_data)
print(output)