在Linux系统中,使用PyTorch保存模型的主要方法是使用torch.save()
函数。以下是一个简单的示例:
import torch
import torchvision.models as models
# 创建一个预训练的ResNet18模型
model = models.resnet18(pretrained=True)
# 保存整个模型
torch.save(model, 'resnet18_model.pth')
# 如果只想保存模型的状态字典(即权重和偏置),可以使用以下方法:
torch.save(model.state_dict(), 'resnet18_state_dict.pth')
在这个例子中,我们首先导入了torch
库和torchvision.models
模块。然后,我们创建了一个预训练的ResNet18模型。接下来,我们使用torch.save()
函数将整个模型保存到一个名为resnet18_model.pth
的文件中。此外,我们还可以选择仅保存模型的状态字典(权重和偏置),而不是整个模型。这可以通过调用model.state_dict()
方法并将其传递给torch.save()
函数来实现。
要加载保存的模型,可以使用以下代码:
# 加载整个模型
model_loaded = torch.load('resnet18_model.pth')
# 加载模型的状态字典
model = models.resnet18(pretrained=False) # 创建一个新的ResNet18模型实例
model.load_state_dict(torch.load('resnet18_state_dict.pth'))
在这个例子中,我们首先使用torch.load()
函数加载保存的模型。然后,我们创建了一个新的ResNet18模型实例,并使用load_state_dict()
方法将之前保存的状态字典加载到新模型中。注意,在加载状态字典时,我们需要确保新模型的架构与保存的模型架构相同。