pytorch中如何使用迁移学习resnet18训练mnist数据集

发布时间:2021-12-04 19:07:29 作者:柒染
来源:亿速云 阅读:567
# PyTorch中如何使用迁移学习ResNet18训练MNIST数据集

## 1. 迁移学习概述

迁移学习(Transfer Learning)是深度学习中的重要技术,它允许我们将在一个任务上训练好的模型参数迁移到另一个相关任务中。这种方法特别适用于以下场景:
- 目标数据集较小(如医学图像)
- 计算资源有限
- 需要快速原型开发

在计算机视觉领域,预训练的CNN模型(如ResNet、VGG等)通过迁移学习可以显著提升在小规模数据集上的表现。本文将详细介绍如何使用PyTorch中的ResNet18模型,通过迁移学习技术来训练MNIST手写数字数据集。

## 2. 环境准备与数据加载

### 2.1 安装必要库

```python
!pip install torch torchvision matplotlib

2.2 导入所需模块

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

2.3 数据预处理

MNIST图像是28x28的灰度图像,而ResNet18默认输入是224x224的3通道图像,需要进行调整:

# 定义数据转换
transform = transforms.Compose([
    transforms.Resize(224),  # 调整大小
    transforms.Grayscale(num_output_channels=3),  # 灰度转RGB
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # 归一化
])

# 加载数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

3. 模型准备与迁移学习策略

3.1 加载预训练ResNet18

model = models.resnet18(pretrained=True)

3.2 模型结构调整

ResNet18原始输出是1000类(ImageNet),我们需要修改最后一层以适应MNIST的10分类任务:

# 冻结所有卷积层参数
for param in model.parameters():
    param.requires_grad = False

# 替换最后的全连接层
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # MNIST有10个类别

3.3 迁移学习策略选择

常见迁移学习策略有: 1. 特征提取器:冻结卷积层,只训练全连接层(本文采用) 2. 微调:解冻部分或全部卷积层进行微调 3. 渐进解冻:逐步解冻网络层

4. 训练过程实现

4.1 定义训练函数

def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        
    train_loss = running_loss / len(train_loader)
    train_acc = 100. * correct / len(train_loader.dataset)
    
    print(f'Train Epoch: {epoch} \tLoss: {train_loss:.4f} \tAccuracy: {train_acc:.2f}%')
    return train_loss, train_acc

4.2 定义测试函数

def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader)
    test_acc = 100. * correct / len(test_loader.dataset)
    
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {test_acc:.2f}%\n')
    return test_loss, test_acc

4.3 主训练循环

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

epochs = 10
train_losses, train_accs = [], []
test_losses, test_accs = [], []

for epoch in range(1, epochs + 1):
    train_loss, train_acc = train(model, device, train_loader, optimizer, criterion, epoch)
    test_loss, test_acc = test(model, device, test_loader, criterion)
    
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    test_losses.append(test_loss)
    test_accs.append(test_acc)

5. 结果可视化与分析

5.1 训练过程可视化

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Accuracy')
plt.plot(test_accs, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.show()

5.2 性能分析

典型训练结果可能显示: - 测试准确率可达98%以上 - 训练曲线快速收敛 - 过拟合现象不明显(得益于预训练特征)

6. 进阶优化策略

6.1 学习率调整

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

6.2 部分层微调

# 解冻最后两个卷积块
for name, param in model.named_parameters():
    if 'layer4' in name or 'layer3' in name:
        param.requires_grad = True

6.3 数据增强

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

7. 常见问题与解决方案

7.1 输入维度不匹配

问题:ResNet期望3通道输入,MNIST是单通道
解决:使用Grayscale(num_output_channels=3)转换

7.2 过拟合

现象:训练准确率高但测试准确率低
解决方案: - 增加数据增强 - 添加Dropout层 - 使用更小的学习率 - 早停(Early Stopping)

7.3 训练速度慢

优化方案: - 使用更大的batch size - 启用混合精度训练 - 分布式训练

8. 完整代码示例

# 省略部分导入和函数定义...

def main():
    # 数据准备
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.Grayscale(3),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    
    train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_set = datasets.MNIST('./data', train=False, download=True, transform=transform)
    
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=128, shuffle=False)
    
    # 模型准备
    model = models.resnet18(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    model.fc = nn.Linear(model.fc.in_features, 10)
    
    # 训练配置
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
    
    # 训练循环
    for epoch in range(1, 11):
        train(model, device, train_loader, optimizer, criterion, epoch)
        test(model, device, test_loader, criterion)

if __name__ == '__main__':
    main()

9. 总结

本文详细介绍了在PyTorch中使用ResNet18进行迁移学习训练MNIST数据集的全过程,关键点包括: 1. 正确处理单通道图像的输入适配 2. 合理冻结/解冻网络层 3. 针对小数据集的训练技巧 4. 模型性能评估与优化

迁移学习大大降低了在特定任务上训练模型的成本,即使像MNIST这样简单的数据集,使用预训练模型也能获得更好的特征表示和泛化能力。这种方法可以轻松扩展到其他类似任务中。 “`

推荐阅读:
  1. 如何使用tensorflow实现VGG网络,训练mnist数据集
  2. pytorch实现建立自己的数据集(以mnist为例)

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

pytorch resnet18 mnist

上一篇:在PyTorch中backward hook在全连接层和卷积层表现不一致的地方是什么

下一篇:pytorch怎样实现特征图可视化

相关阅读

您好,登录后才能下订单哦!

密码登录
登录注册
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》