linux

Linux中PyTorch模型怎么保存

小樊
35
2025-06-01 04:16:47
栏目: 智能运维

在Linux系统中,使用PyTorch保存模型的主要方法是使用torch.save()函数。以下是一个简单的示例:

import torch
import torchvision.models as models

# 创建一个预训练的ResNet18模型
model = models.resnet18(pretrained=True)

# 保存整个模型
torch.save(model, 'resnet18_model.pth')

# 如果只想保存模型的状态字典(即权重和偏置),可以使用以下方法:
torch.save(model.state_dict(), 'resnet18_state_dict.pth')

在这个例子中,我们首先导入了torch库和torchvision.models模块。然后,我们创建了一个预训练的ResNet18模型。接下来,我们使用torch.save()函数将整个模型保存到一个名为resnet18_model.pth的文件中。此外,我们还可以选择仅保存模型的状态字典(权重和偏置),而不是整个模型。这可以通过调用model.state_dict()方法并将其传递给torch.save()函数来实现。

要加载保存的模型,可以使用以下代码:

# 加载整个模型
model_loaded = torch.load('resnet18_model.pth')

# 加载模型的状态字典
model = models.resnet18(pretrained=False)  # 创建一个新的ResNet18模型实例
model.load_state_dict(torch.load('resnet18_state_dict.pth'))

在这个例子中,我们首先使用torch.load()函数加载保存的模型。然后,我们创建了一个新的ResNet18模型实例,并使用load_state_dict()方法将之前保存的状态字典加载到新模型中。注意,在加载状态字典时,我们需要确保新模型的架构与保存的模型架构相同。

0
看了该问题的人还看了