# How to Implement the ResNet Architecture in PyTorch
## 1. Introduction
### 1.1 Background and Significance of ResNet
The Residual Network (ResNet) is a deep convolutional neural network architecture proposed in 2015 by Kaiming He and colleagues at Microsoft Research. It is a milestone in computer vision, primarily addressing the "degradation problem" in training deep neural networks: as network depth increases, accuracy saturates and then degrades rapidly.
Conventional wisdom held that deeper networks should be able to learn more complex feature representations and thus perform better. Experiments showed, however, that naively increasing depth causes vanishing/exploding gradients, and even when the gradient problem is mitigated by normalized initialization, performance still degrades. ResNet introduced the concept of "residual learning", which makes it easier for layers to learn identity mappings and thereby makes training very deep networks feasible.
### 1.2 Main Contributions of ResNet
1. Proposed the residual learning framework, solving the degradation problem of deep networks
2. Realized identity mappings through shortcut connections
3. Achieved the best results of its time on datasets such as ImageNet and COCO
4. Allowed network depth to scale easily beyond 100 layers (ResNet-152)
### 1.3 Why Implement It in PyTorch
PyTorch, one of today's mainstream deep learning frameworks, offers the following advantages:
- Dynamic computation graphs for more flexible model construction
- A clean and intuitive API
- Strong GPU acceleration support
- An active community and rich ecosystem
Implementing ResNet in PyTorch not only helps in understanding its core ideas, but also builds practical skills with a modern deep learning framework.
## 2. Core Principles of ResNet
### 2.1 The Idea of Residual Learning
The core formula of residual learning:
$$
\mathbf{y} = \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathbf{x}
$$
where:
- $\mathbf{x}$ and $\mathbf{y}$ are the input and output
- $\mathcal{F}(\mathbf{x}, \{W_i\})$ is the residual mapping to be learned
- the addition is realized by a shortcut connection
When the desired mapping $H(\mathbf{x})$ is complex, it is usually easier for the network to learn the residual $\mathcal{F}(\mathbf{x}) = H(\mathbf{x}) - \mathbf{x}$.
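A minimal numeric sketch (plain Python, purely illustrative, not an actual network) of why residuals can be easier to fit: if the desired mapping $H$ is close to the identity, the residual $F(x) = H(x) - x$ stays near zero, and "near zero" is an easy target for layers initialized around zero. The mapping `H` below is a made-up example.

```python
def H(x):
    # Hypothetical desired mapping (assumption for illustration): near-identity
    return 1.01 * x + 0.02

def F(x):
    # The residual the layers would actually learn: F(x) = H(x) - x
    return H(x) - x

x = 1.0
y = F(x) + x        # y = F(x) + x, matching the formula above
print(y, H(x))      # identical up to floating-point rounding
print(F(x))         # the residual itself is tiny compared to x
```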
### 2.2 Residual Block Design
ResNet is built from residual blocks, which come in two main types:
1. **BasicBlock**: used in shallower networks (e.g., ResNet-18/34)
   ```python
   class BasicBlock(nn.Module):
       expansion = 1

       def __init__(self, in_channels, out_channels, stride=1):
           super().__init__()
           self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                  stride=stride, padding=1, bias=False)
           self.bn1 = nn.BatchNorm2d(out_channels)
           self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                                  stride=1, padding=1, bias=False)
           self.bn2 = nn.BatchNorm2d(out_channels)

           self.shortcut = nn.Sequential()
           if stride != 1 or in_channels != self.expansion * out_channels:
               self.shortcut = nn.Sequential(
                   nn.Conv2d(in_channels, self.expansion * out_channels,
                             kernel_size=1, stride=stride, bias=False),
                   nn.BatchNorm2d(self.expansion * out_channels)
               )
   ```
2. **BottleneckBlock**: used in deeper networks (e.g., ResNet-50/101/152)
   ```python
   class BottleneckBlock(nn.Module):
       expansion = 4

       def __init__(self, in_channels, out_channels, stride=1):
           super().__init__()
           self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
           self.bn1 = nn.BatchNorm2d(out_channels)
           self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                                  stride=stride, padding=1, bias=False)
           self.bn2 = nn.BatchNorm2d(out_channels)
           self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels,
                                  kernel_size=1, bias=False)
           self.bn3 = nn.BatchNorm2d(self.expansion * out_channels)

           self.shortcut = nn.Sequential()
           if stride != 1 or in_channels != self.expansion * out_channels:
               self.shortcut = nn.Sequential(
                   nn.Conv2d(in_channels, self.expansion * out_channels,
                             kernel_size=1, stride=stride, bias=False),
                   nn.BatchNorm2d(self.expansion * out_channels)
               )
   ```
### 2.3 ResNet Variants
ResNet comes in several variants that differ mainly in depth and block type:

| Model | Layers | Block type | Parameters (M) |
|---|---|---|---|
| ResNet-18 | 18 | BasicBlock | 11.7 |
| ResNet-34 | 34 | BasicBlock | 21.8 |
| ResNet-50 | 50 | Bottleneck | 25.6 |
| ResNet-101 | 101 | Bottleneck | 44.5 |
| ResNet-152 | 152 | Bottleneck | 60.2 |
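The depth numbers in the table can be verified by counting weighted layers: one stem convolution, the convolutions inside the residual blocks (2 per BasicBlock, 3 per bottleneck block), and the final fully connected layer. A quick plain-Python check:

```python
# Block counts per stage for each variant, and convolutions per block type
configs = {
    "ResNet-18":  ("basic", [2, 2, 2, 2]),
    "ResNet-34":  ("basic", [3, 4, 6, 3]),
    "ResNet-50":  ("bottleneck", [3, 4, 6, 3]),
    "ResNet-101": ("bottleneck", [3, 4, 23, 3]),
    "ResNet-152": ("bottleneck", [3, 8, 36, 3]),
}
convs_per_block = {"basic": 2, "bottleneck": 3}

for name, (kind, blocks) in configs.items():
    # 1 stem conv + all convolutions inside blocks + 1 fully connected layer
    depth = 1 + sum(blocks) * convs_per_block[kind] + 1
    print(name, depth)  # matches the layer count in the model name
```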
## 3. Complete Implementation
Putting everything together, the complete network with both block types, the overall `ResNet` class, and factory functions for each variant:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class BottleneckBlock(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super(BottleneckBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7,
                               stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def ResNet18(num_classes=1000):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)


def ResNet34(num_classes=1000):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)


def ResNet50(num_classes=1000):
    return ResNet(BottleneckBlock, [3, 4, 6, 3], num_classes)


def ResNet101(num_classes=1000):
    return ResNet(BottleneckBlock, [3, 4, 23, 3], num_classes)


def ResNet152(num_classes=1000):
    return ResNet(BottleneckBlock, [3, 8, 36, 3], num_classes)
```
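To see how a 224×224 input shrinks to the 7×7 feature map that the adaptive average pool receives, the standard convolution output-size formula can be walked through stage by stage. This plain-Python check mirrors the strides used in the network:

```python
def conv_out(size, kernel, stride, padding):
    # Standard formula: floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

size = 224
size = conv_out(size, 7, 2, 3)       # stem conv1 -> 112
size = conv_out(size, 3, 2, 1)       # maxpool    -> 56
for stride in (1, 2, 2, 2):          # first block of layer1..layer4
    size = conv_out(size, 3, stride, 1)
print(size)  # 7: the spatial size entering the adaptive average pool
```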
Key implementation details of the code above:
- **Shortcut handling**: when the stride is not 1 or the channel counts differ, the shortcut uses a 1×1 convolution plus batch normalization to match dimensions; otherwise it is an identity (an empty `nn.Sequential`).
- **Downsampling strategy**: spatial downsampling is performed by the first block of each stage (stride 2 in `layer2` through `layer4`), not by extra pooling layers.
- **Batch normalization**: every convolution is followed by `BatchNorm2d`, so the convolutions omit their bias terms (`bias=False`).
- **Weight initialization**: convolution weights use Kaiming (He) normal initialization; batch-norm weights and biases are initialized to 1 and 0.
## 4. Training the Model
### 4.1 Data Preparation

```python
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Training-time data augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Validation-time transforms
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load the datasets
train_dataset = datasets.ImageFolder('path/to/train', train_transform)
val_dataset = datasets.ImageFolder('path/to/val', val_transform)
train_loader = DataLoader(train_dataset, batch_size=64,
                          shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64,
                        shuffle=False, num_workers=4)
```
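`ToTensor` scales pixel values into [0, 1], and `Normalize` then subtracts the per-channel mean and divides by the per-channel standard deviation (the ImageNet statistics above). The same arithmetic for a single RGB pixel, in plain Python (the sample pixel value is arbitrary):

```python
mean = [0.485, 0.456, 0.406]   # ImageNet channel means (R, G, B)
std  = [0.229, 0.224, 0.225]   # ImageNet channel standard deviations

def normalize_pixel(rgb_255):
    # ToTensor: map 0-255 values to floats in [0, 1]
    scaled = [v / 255.0 for v in rgb_255]
    # Normalize: per-channel (x - mean) / std
    return [(s - m) / d for s, m, d in zip(scaled, mean, std)]

# A pixel close to the dataset mean normalizes to values near zero
print(normalize_pixel((124, 116, 104)))
```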
### 4.2 Training and Validation Loops

```python
import torch.optim as optim
from tqdm import tqdm


def train(model, device, train_loader, optimizer, epoch, criterion):
    model.train()
    total_loss = 0
    correct = 0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
    for data, target in pbar:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

        pbar.set_postfix({'loss': loss.item()})

    avg_loss = total_loss / len(train_loader)
    accuracy = 100. * correct / len(train_loader.dataset)
    return avg_loss, accuracy


def validate(model, device, val_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    avg_loss = total_loss / len(val_loader)
    accuracy = 100. * correct / len(val_loader.dataset)
    return avg_loss, accuracy


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ResNet50(num_classes=1000).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1,
                          momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.0
    for epoch in range(1, 91):
        train_loss, train_acc = train(model, device, train_loader,
                                      optimizer, epoch, criterion)
        val_loss, val_acc = validate(model, device, val_loader, criterion)
        scheduler.step()

        print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Acc: {train_acc:.2f}% | '
              f'Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%')

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'resnet50_best.pth')

    print(f'Best Validation Accuracy: {best_acc:.2f}%')


if __name__ == '__main__':
    main()
```
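As a sanity check on the `StepLR` schedule configured in `main()` (initial learning rate 0.1, `step_size=30`, `gamma=0.1`), the effective learning rate at any epoch can be computed directly:

```python
base_lr, step_size, gamma = 0.1, 30, 0.1

def lr_at_epoch(epoch):
    # StepLR: lr = base_lr * gamma ** (epoch // step_size), epochs counted from 0
    return base_lr * gamma ** (epoch // step_size)

for epoch in (0, 29, 30, 59, 60, 89):
    print(epoch, lr_at_epoch(epoch))
# Over 90 epochs the rate steps through 0.1, 0.01 and 0.001
```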
Some practical training tips:
- **Learning-rate scheduling**: the `StepLR` schedule above multiplies the learning rate by 0.1 every 30 epochs; cosine annealing and warmup are common alternatives.
- **Optimizer choice**: SGD with momentum 0.9 and weight decay 1e-4 is the classic recipe for training ResNet on ImageNet; adaptive optimizers such as Adam often converge faster but may generalize slightly worse in this setting.
- **Regularization**: weight decay plus the data augmentation shown above (random crops and horizontal flips); label smoothing is another common addition.
- **Mixed-precision training**: automatic mixed precision can speed up training and reduce GPU memory usage:

```python
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

def train(...):
    ...
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    ...
```
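A note on why the `GradScaler` is needed: gradients that are small but nonzero in float32 can underflow to exactly zero in float16, and scaling the loss (the scaler's default initial scale is 65536) keeps them representable. A minimal plain-Python illustration of the underflow itself, not the actual AMP machinery:

```python
import struct

def to_fp16(x):
    # Round-trip a Python float through IEEE 754 half precision
    return struct.unpack('e', struct.pack('e', x))[0]

grad = 1e-8                    # a tiny gradient value
print(to_fp16(grad))           # underflows to 0.0 in half precision
print(to_fp16(grad * 65536))   # after loss scaling, the value survives
```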
Finally, two influential ResNet variants:
- **ResNeXt**: replaces the 3×3 convolution in the bottleneck with a grouped convolution, introducing "cardinality" (the number of groups) as a new dimension alongside depth and width.
- **Wide ResNet**: reduces depth but widens the residual blocks (more channels), reaching comparable accuracy with faster training.
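To make the ResNeXt point concrete: a grouped convolution with $g$ groups uses $1/g$ as many weights as a full convolution with the same channel counts, which is what frees the parameter budget for higher cardinality. A plain-Python weight count, with channel numbers chosen purely for illustration:

```python
c_in, c_out, k, groups = 128, 128, 3, 32   # illustrative sizes, cardinality 32

# Full convolution: every input channel connects to every output channel
full_weights = c_in * c_out * k * k

# Grouped convolution: each group connects c_in/groups inputs to c_out/groups outputs
grouped_weights = groups * (c_in // groups) * (c_out // groups) * k * k

print(full_weights, grouped_weights, full_weights // grouped_weights)
```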