# How to Implement the ResNet Architecture in PyTorch

## 1. Introduction

### 1.1 Background and Significance of ResNet

The Residual Network (ResNet) is a deep convolutional neural network architecture proposed in 2015 by Kaiming He and colleagues at Microsoft Research. It is a milestone in computer vision, chiefly because it addresses the "degradation problem" in training deep neural networks: as network depth increases, accuracy saturates and then degrades rapidly.

Conventional wisdom held that deeper networks should be able to learn richer feature representations and therefore perform better. In practice, however, simply stacking more layers leads to vanishing/exploding gradients, and even when the gradient issues are mitigated with normalized initialization and related techniques, performance still drops. ResNet introduces the idea of "residual learning", which makes it easy for the network to learn identity mappings and thereby makes training very deep networks feasible.

### 1.2 Main Contributions of ResNet

1. Proposed the residual learning framework, solving the degradation problem in deep networks
2. Realized identity mappings through shortcut connections
3. Achieved state-of-the-art results at the time on ImageNet, COCO, and other benchmarks
4. Allowed network depth to scale easily beyond 100 layers (e.g. ResNet-152)

### 1.3 Why Implement It in PyTorch

PyTorch, one of today's mainstream deep learning frameworks, offers several advantages:
- Dynamic computation graphs and a flexible way to build models
- A concise, intuitive API
- Strong GPU acceleration support
- An active community and a rich ecosystem

Implementing ResNet in PyTorch not only helps in understanding its core ideas, but also builds practical skills with a modern deep learning framework.

## 2. Core Principles of ResNet

### 2.1 The Basic Idea of Residual Learning

The core formula of residual learning:

$$
\mathbf{y} = \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathbf{x}
$$

where:
- $\mathbf{x}$ and $\mathbf{y}$ are the input and output
- $\mathcal{F}(\mathbf{x}, \{W_i\})$ is the residual mapping to be learned
- the addition is realized through a shortcut connection

When the desired mapping $H(\mathbf{x})$ is complex, it is usually easier for the network to learn the residual $\mathcal{F}(\mathbf{x}) = H(\mathbf{x}) - \mathbf{x}$ instead.
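
In code, the shortcut is just an element-wise addition in the forward pass. The snippet below is a minimal sketch of this pattern only, not the full ResNet block; `ResidualWrapper` and `residual_fn` are illustrative names and stand for any stack of layers computing $\mathcal{F}$ whose output shape matches the input.

```python
import torch
import torch.nn as nn

class ResidualWrapper(nn.Module):
    """Minimal sketch of a residual connection: y = F(x) + x."""
    def __init__(self, residual_fn: nn.Module):
        super().__init__()
        self.residual_fn = residual_fn  # layers computing F(x); output shape must match x

    def forward(self, x):
        return self.residual_fn(x) + x  # shortcut: add the input back onto the residual

# Example: wrap two 3x3 convolutions that preserve the channel count and spatial size
block = ResidualWrapper(nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
))
y = block(torch.randn(1, 64, 32, 32))  # output shape matches the input: (1, 64, 32, 32)
```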

### 2.2 Residual Block Design

The basic building block of ResNet is the residual block, which comes in two main types:

1. **BasicBlock**: used in shallower networks (e.g. ResNet-18/34)
   ```python
   class BasicBlock(nn.Module):
       expansion = 1
       
       def __init__(self, in_channels, out_channels, stride=1):
           super().__init__()
           self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                                 stride=stride, padding=1, bias=False)
           self.bn1 = nn.BatchNorm2d(out_channels)
           self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                                 stride=1, padding=1, bias=False)
           self.bn2 = nn.BatchNorm2d(out_channels)
           
           self.shortcut = nn.Sequential()
           if stride != 1 or in_channels != self.expansion * out_channels:
               self.shortcut = nn.Sequential(
                   nn.Conv2d(in_channels, self.expansion * out_channels,
                            kernel_size=1, stride=stride, bias=False),
                   nn.BatchNorm2d(self.expansion * out_channels)
               )
   ```

2. **BottleneckBlock**: used in deeper networks (e.g. ResNet-50/101/152)

   ```python
   class BottleneckBlock(nn.Module):
       expansion = 4

       def __init__(self, in_channels, out_channels, stride=1):
           super().__init__()
           self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
           self.bn1 = nn.BatchNorm2d(out_channels)
           self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                                 stride=stride, padding=1, bias=False)
           self.bn2 = nn.BatchNorm2d(out_channels)
           self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels,
                                 kernel_size=1, bias=False)
           self.bn3 = nn.BatchNorm2d(self.expansion * out_channels)

           self.shortcut = nn.Sequential()
           if stride != 1 or in_channels != self.expansion * out_channels:
               self.shortcut = nn.Sequential(
                   nn.Conv2d(in_channels, self.expansion * out_channels,
                            kernel_size=1, stride=stride, bias=False),
                   nn.BatchNorm2d(self.expansion * out_channels)
               )
   ```

Only the constructors are shown here; the `forward` methods appear in the complete implementation in Section 3.2.

### 2.3 Architecture Overview

ResNet comes in several variants, differing mainly in depth and block type:

| Model | Layers | Residual block type | Parameters (M) |
|------------|-----|------------|------|
| ResNet-18  | 18  | BasicBlock | 11.7 |
| ResNet-34  | 34  | BasicBlock | 21.8 |
| ResNet-50  | 50  | Bottleneck | 25.6 |
| ResNet-101 | 101 | Bottleneck | 44.5 |
| ResNet-152 | 152 | Bottleneck | 60.2 |
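
The layer count refers to weighted layers (convolutions plus the final fully connected layer). As a worked example, assuming the standard stage configuration [3, 4, 6, 3] for ResNet-50 with three convolutions per bottleneck block:

$$
\underbrace{1}_{\text{stem conv}} + \underbrace{3 \times (3 + 4 + 6 + 3)}_{\text{bottleneck convs}} + \underbrace{1}_{\text{fc}} = 1 + 48 + 1 = 50
$$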

## 3. Detailed PyTorch Implementation

### 3.1 Basic Implementation Steps

1. Implement the residual blocks (BasicBlock/Bottleneck)
2. Build the ResNet backbone
3. Add global average pooling and the fully connected layer
4. Implement weight initialization

### 3.2 Complete Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, 
            stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels,
                         kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class BottleneckBlock(nn.Module):
    expansion = 4
    
    def __init__(self, in_channels, out_channels, stride=1):
        super(BottleneckBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 
                              kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 
                              kernel_size=3, stride=stride, 
                              padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels, 
                              kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels,
                         kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_channels = 64
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, 
                              stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        
        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', 
                                      nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

def ResNet18(num_classes=1000):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

def ResNet34(num_classes=1000):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)

def ResNet50(num_classes=1000):
    return ResNet(BottleneckBlock, [3, 4, 6, 3], num_classes)

def ResNet101(num_classes=1000):
    return ResNet(BottleneckBlock, [3, 4, 23, 3], num_classes)

def ResNet152(num_classes=1000):
    return ResNet(BottleneckBlock, [3, 8, 36, 3], num_classes)
```
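
A quick way to sanity-check the implementation is to push a dummy batch through one of the constructors and confirm the output shape and parameter count; the snippet below assumes the classes defined above are in scope.

```python
# Sanity check: a random 224x224 RGB batch should yield (batch, num_classes) logits
model = ResNet18(num_classes=1000)
dummy = torch.randn(2, 3, 224, 224)
logits = model(dummy)
print(logits.shape)  # expected: torch.Size([2, 1000])

# Rough parameter count, for comparison with the table in Section 2.3
num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 1e6:.1f}M parameters")  # roughly 11.7M for ResNet-18
```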

### 3.3 Key Implementation Details

1. **Shortcut connections**
   - When the input and output dimensions match, the input is added directly
   - When they do not match, a 1x1 convolution adjusts the channel count and spatial size (see the shape trace after this list)
2. **Downsampling strategy**
   - The stem is a 7x7 convolution with stride 2
   - The first residual block of each stage downsamples with stride 2
   - Max pooling further reduces the feature map size
3. **Batch normalization**
   - A BatchNorm layer follows every convolution
   - Normalization is applied before the ReLU activation
4. **Weight initialization**
   - Convolution weights use He (Kaiming) initialization
   - BatchNorm γ is initialized to 1 and β to 0
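
As an illustration of the shortcut handling in point 1, the snippet below builds a single downsampling BasicBlock (stride 2 plus a channel change) and traces the shapes; it assumes the `BasicBlock` class from Section 3.2 is in scope.

```python
# A downsampling block: 64 -> 128 channels, stride 2 halves the spatial size
block = BasicBlock(in_channels=64, out_channels=128, stride=2)

x = torch.randn(1, 64, 56, 56)
print(block.shortcut(x).shape)  # torch.Size([1, 128, 28, 28]) - 1x1 conv matches the main path
print(block(x).shape)           # torch.Size([1, 128, 28, 28]) - residual addition is now valid
```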

## 4. Model Training and Evaluation

### 4.1 Data Preparation and Augmentation

```python
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Training data augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

# Validation transforms
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

# Load the datasets
train_dataset = datasets.ImageFolder('path/to/train', train_transform)
val_dataset = datasets.ImageFolder('path/to/val', val_transform)

train_loader = DataLoader(train_dataset, batch_size=64, 
                         shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64,
                        shuffle=False, num_workers=4)
```
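
As a quick usage check (assuming `'path/to/train'` and `'path/to/val'` are replaced with real ImageFolder-style directories), one batch from the loader can be inspected like this:

```python
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 3, 224, 224]) after RandomResizedCrop(224)
print(labels.shape)  # torch.Size([64]) - one class index per image
```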

### 4.2 The Training Loop

```python
import torch.optim as optim
from tqdm import tqdm

def train(model, device, train_loader, optimizer, epoch, criterion):
    model.train()
    total_loss = 0
    correct = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
    for data, target in pbar:
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        
        pbar.set_postfix({'loss': loss.item()})
    
    avg_loss = total_loss / len(train_loader)
    accuracy = 100. * correct / len(train_loader.dataset)
    return avg_loss, accuracy

def validate(model, device, val_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    avg_loss = total_loss / len(val_loader)
    accuracy = 100. * correct / len(val_loader.dataset)
    return avg_loss, accuracy

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model = ResNet50(num_classes=1000).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1, 
                        momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    criterion = nn.CrossEntropyLoss()
    
    best_acc = 0.0
    for epoch in range(1, 91):
        train_loss, train_acc = train(model, device, train_loader, 
                                     optimizer, epoch, criterion)
        val_loss, val_acc = validate(model, device, val_loader, criterion)
        scheduler.step()
        
        print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}% | '
              f'Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%')
        
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'resnet50_best.pth')
    
    print(f'Best Validation Accuracy: {best_acc:.2f}%')

if __name__ == '__main__':
    main()
```
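
After training, the best weights saved as `resnet50_best.pth` above can be restored for evaluation or inference; a minimal sketch, assuming the same `ResNet50` constructor:

```python
# Restore the best checkpoint saved during training
model = ResNet50(num_classes=1000)
model.load_state_dict(torch.load('resnet50_best.pth', map_location='cpu'))
model.eval()  # switch BatchNorm to inference mode before evaluating
```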

### 4.3 Training Tricks and Optimization

1. **Learning rate schedule**
   - The initial learning rate is 0.1
   - It is multiplied by 0.1 every 30 epochs
2. **Optimizer choice**
   - SGD with momentum
   - Weight decay set to 1e-4
3. **Regularization**
   - Data augmentation to counter overfitting
   - Dropout (optional)
4. **Mixed-precision training**, sketched below by wrapping the forward pass in `autocast` and scaling the loss with `GradScaler`:

   ```python
   from torch.cuda.amp import GradScaler, autocast

   scaler = GradScaler()

   def train(...):
       ...
       with autocast():
           output = model(data)
           loss = criterion(output, target)

       scaler.scale(loss).backward()
       scaler.step(optimizer)
       scaler.update()
       ...
   ```

## 5. Advanced Topics and Variants

### 5.1 ResNet Variants

1. **ResNeXt**
   - Introduces grouped convolutions (see the sketch after this list)
   - Treats cardinality as a new dimension
   - Formula: $\mathcal{F}(\mathbf{x}) = \sum_{i=1}^{C} \mathcal{T}_i(\mathbf{x})$
2. **Wide ResNet**
   - Increases the number of channels per layer
   - Reduces the network depth
   - Trains faster
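
The aggregated-transformation formula above is commonly realized with the `groups` argument of `nn.Conv2d`. The block below is only a rough sketch of that idea, not the exact ResNeXt block from the paper; the class name and channel numbers are illustrative.

```python
import torch
import torch.nn as nn

class GroupedBottleneck(nn.Module):
    """Rough ResNeXt-style bottleneck: the 3x3 conv is split into `cardinality` groups."""
    def __init__(self, channels=256, width=128, cardinality=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, width, kernel_size=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            # Grouped 3x3 convolution: each of the 32 groups transforms width/32 channels
            nn.Conv2d(width, width, kernel_size=3, padding=1, groups=cardinality, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        return torch.relu(self.net(x) + x)  # residual connection, as in ResNet

x = torch.randn(1, 256, 14, 14)
print(GroupedBottleneck()(x).shape)  # torch.Size([1, 256, 14, 14])
```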