在PyTorch中,预训练的分类任务通常涉及以下步骤:
数据准备:
选择预训练模型:
加载预训练模型:
torchvision.models
模块中的相应函数来加载预训练模型。例如,要加载预训练的ResNet-18模型,你可以使用torchvision.models.resnet18(pretrained=True)
。修改最后一层:
nn.Linear(model.fc.in_features, 10)
来替换最后一层。微调模型:
评估和测试:
下面是一个简单的示例代码,展示了如何使用PyTorch加载预训练的ResNet-18模型并进行分类任务:
import torch
import torchvision.transforms as transforms
import torchvision.models as models
# 数据预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载预训练模型
model = models.resnet18(pretrained=True)
# 修改最后一层
num_classes = 10 # 根据你的任务设置类别数量
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
# 微调模型(这里只是示例,实际应用中你需要编写训练循环)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 假设你有一个数据加载器 data_loader
for epoch in range(num_epochs):
for images, labels in data_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 评估模型
# ...
请注意,上述代码只是一个简化的示例,实际应用中你需要根据你的具体任务和数据集来调整数据预处理、模型修改和微调过程。