# How to Read a Cat-vs-Dog Binary Classification Training Set in PyTorch
## 1. Introduction
Image classification is one of the most fundamental and widely applied tasks in deep learning, and cat-vs-dog classification is a classic binary problem often used for teaching and algorithm validation. PyTorch, one of today's most popular deep learning frameworks, provides a complete toolchain for this kind of task. This article walks through reading a cat-vs-dog binary classification training set with PyTorch, covering the full pipeline from data preparation to model training.
### 1.1 Why PyTorch
PyTorch offers the following advantages:
- Dynamic computation graphs
- A concise, intuitive API
- Active community support
- Seamless integration with the Python ecosystem
- Mature GPU acceleration support
### 1.2 Article Structure
This article proceeds as follows:
1. Dataset preparation and directory structure
2. PyTorch's core data-reading components
3. Implementing a custom dataset class
4. Data augmentation and preprocessing
5. DataLoader configuration
6. A complete code example
7. Common problems and solutions
---
## 2. Dataset Preparation and Directory Structure
### 2.1 Obtaining a Standard Dataset
The Kaggle "Dogs vs Cats" dataset is recommended:
```bash
kaggle competitions download -c dogs-vs-cats
```
After extraction, the dataset should have the following structure:
```
data/
├── train/
│   ├── cat.0.jpg
│   ├── cat.1.jpg
│   ├── ...
│   ├── dog.0.jpg
│   ├── dog.1.jpg
│   └── ...
└── test/
    ├── 0.jpg
    ├── 1.jpg
    └── ...
```
### 2.2 Recommended Directory Structure
For training, a class-per-folder layout (one subdirectory per class) is recommended:
```
custom_data/
├── train/
│   ├── cat/
│   │   ├── cat001.jpg
│   │   └── ...
│   └── dog/
│       ├── dog001.jpg
│       └── ...
└── val/
    ├── cat/
    └── dog/
```
A typical split:
- Training set: 20,000 images (10,000 cats, 10,000 dogs)
- Validation set: 5,000 images (2,500 of each)
- Test set: 12,500 images (unlabeled)
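The flat Kaggle layout can be reorganized into the class-per-folder structure with a short script. This is a sketch that assumes the original `cat.N.jpg` / `dog.N.jpg` naming; `organize_by_class` is a hypothetical helper name:

```python
import os
import shutil

def organize_by_class(src_dir, dst_dir, classes=('cat', 'dog')):
    """Move cat.*.jpg / dog.*.jpg files from a flat directory
    into one subdirectory per class (dst_dir/cat, dst_dir/dog)."""
    for cls in classes:
        os.makedirs(os.path.join(dst_dir, cls), exist_ok=True)
    for name in os.listdir(src_dir):
        prefix = name.split('.')[0]  # 'cat' or 'dog' in Kaggle's naming
        if prefix in classes:
            shutil.move(os.path.join(src_dir, name),
                        os.path.join(dst_dir, prefix, name))
```

Carving the validation split out of `train/` afterwards is a matter of moving a fixed number of files per class the same way.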
## 3. PyTorch's Core Data-Reading Components
### 3.1 The `Dataset` Base Class
Map-style datasets subclass `torch.utils.data.Dataset`, whose interface is essentially:
```python
class Dataset(Generic[T_co]):
    def __getitem__(self, index) -> T_co:
        ...

    def __len__(self) -> int:
        ...
```
### 3.2 The `DataLoader` Class
Key parameters:
```python
DataLoader(
    dataset,
    batch_size=32,      # samples per batch
    shuffle=True,       # reshuffle indices every epoch
    num_workers=4,      # subprocesses used for loading
    pin_memory=True,    # faster host-to-GPU transfers
    drop_last=False     # keep the final, smaller batch
)
```
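To see how these parameters shape the batches, here is a minimal self-contained iteration over a synthetic `TensorDataset` standing in for the image dataset (shapes chosen to match the 224x224 pipeline used throughout this article):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# 100 fake "images" (3x224x224) with binary labels
images = torch.randn(100, 3, 224, 224)
labels = torch.randint(0, 2, (100,))
loader = DataLoader(TensorDataset(images, labels),
                    batch_size=32, shuffle=True, drop_last=False)

for batch_images, batch_labels in loader:
    # Three batches of 32, then one of 4 (100 % 32), since drop_last=False
    print(batch_images.shape, batch_labels.shape)
```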
### 3.3 `torchvision.transforms`
Commonly used transforms:
```python
transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(          # ImageNet statistics
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
```
## 4. Implementing a Custom Dataset Class
### 4.1 Basic Implementation
```python
import os
from PIL import Image
import torch
from torch.utils.data import Dataset

class CatDogDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['cat', 'dog']
        self.samples = []

        for class_idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            # sorted() keeps sample order deterministic across filesystems
            for img_name in sorted(os.listdir(class_dir)):
                self.samples.append(
                    (os.path.join(class_dir, img_name), class_idx)
                )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label
```
### 4.2 Cached Variant
To cut repeated disk reads, decoded images can be cached in memory. The cache must hold the image *before* the transform is applied; otherwise random augmentations would be frozen at their first-epoch values:
```python
class CachedCatDogDataset(CatDogDataset):
    def __init__(self, root_dir, transform=None, cache_size=1000):
        super().__init__(root_dir, transform)
        self.cache = {}
        self.cache_size = cache_size

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]

        if idx in self.cache:
            image = self.cache[idx]
        else:
            image = Image.open(img_path).convert('RGB')
            if len(self.cache) < self.cache_size:
                self.cache[idx] = image

        if self.transform:
            image = self.transform(image)

        return image, label
```
## 5. Data Augmentation and Preprocessing
### 5.1 torchvision Transforms
```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2
    ),
    transforms.ToTensor(),
    transforms.Normalize(          # ImageNet statistics
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
```
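The mean/std values above are the standard ImageNet statistics. If you would rather normalize with statistics from your own training set, they can be computed over unnormalized `ToTensor()` outputs; this is a sketch with a hypothetical `compute_mean_std` helper:

```python
import torch
from torch.utils.data import DataLoader

def compute_mean_std(dataset, batch_size=64):
    """Per-channel mean/std over a dataset yielding (C,H,W) tensors in [0, 1]."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    n_pixels = 0
    channel_sum = torch.zeros(3)
    channel_sq_sum = torch.zeros(3)
    for images, _ in loader:
        b, c, h, w = images.shape
        n_pixels += b * h * w                       # pixels per channel
        channel_sum += images.sum(dim=(0, 2, 3))
        channel_sq_sum += (images ** 2).sum(dim=(0, 2, 3))
    mean = channel_sum / n_pixels
    std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()  # E[x^2] - E[x]^2
    return mean, std
```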
### 5.2 Using Albumentations (Optional)
Albumentations operates on NumPy arrays (as read by OpenCV) rather than PIL images:
```python
from albumentations import (
    HorizontalFlip, Rotate, RandomBrightnessContrast,
    HueSaturationValue, Compose
)
import cv2

def albumentations_transform():
    return Compose([
        HorizontalFlip(p=0.5),
        Rotate(limit=15, p=0.5),
        RandomBrightnessContrast(p=0.2),
        HueSaturationValue(
            hue_shift_limit=20,
            sat_shift_limit=30,
            val_shift_limit=20,
            p=0.5
        )
    ])

class AlbumentationsDataset(CatDogDataset):
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        # HWC uint8 -> CHW float in [0, 1]; add ImageNet normalization here
        # if the model expects it, to match the torchvision pipeline above
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        return image, label
```
## 6. DataLoader Configuration
### 6.1 Basic Setup
```python
from torch.utils.data import DataLoader

train_dataset = CatDogDataset(
    root_dir='data/train',
    transform=train_transform
)
val_dataset = CatDogDataset(
    root_dir='data/val',
    transform=val_transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4
)
```
### 6.2 Custom `collate_fn`
The example below reproduces the default collation for (image, label) pairs; customize it when samples need special batching logic:
```python
def collate_fn(batch):
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    return torch.stack(images), torch.tensor(labels)

loader = DataLoader(
    train_dataset,
    batch_size=64,
    collate_fn=collate_fn
)
```
### 6.3 Weighted Sampling for Class Imbalance
Note that `sampler` and `shuffle=True` are mutually exclusive:
```python
from torch.utils.data.sampler import WeightedRandomSampler

class_counts = [10000, 10000]  # number of cat and dog samples
labels = torch.tensor([label for _, label in train_dataset.samples])
weights = 1. / torch.tensor(class_counts, dtype=torch.float)
samples_weights = weights[labels]   # one weight per sample

sampler = WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),
    replacement=True
)
loader = DataLoader(
    train_dataset,
    batch_size=64,
    sampler=sampler
)
```
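A quick sanity check that inverse-frequency weighting actually rebalances the draw: with a synthetic 90/10 label distribution, the weighted sampler should return the minority class roughly half the time:

```python
import torch
from torch.utils.data.sampler import WeightedRandomSampler

torch.manual_seed(0)
labels = torch.tensor([0] * 90 + [1] * 10)       # 90/10 imbalance
class_counts = torch.bincount(labels).float()
samples_weights = (1.0 / class_counts)[labels]   # per-sample weight

sampler = WeightedRandomSampler(samples_weights,
                                num_samples=1000, replacement=True)
drawn = labels[torch.tensor(list(sampler))]
minority_share = (drawn == 1).float().mean().item()  # expected near 0.5
```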
## 7. Complete Code Example
```python
import os
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm

# 1. Define the dataset class
class CatDogDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['cat', 'dog']
        self.samples = []

        for class_idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for img_name in sorted(os.listdir(class_dir)):
                self.samples.append(
                    (os.path.join(class_dir, img_name), class_idx)
                )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# 2. Define the data transforms
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 3. Create datasets and data loaders
train_dataset = CatDogDataset('data/train', train_transform)
val_dataset = CatDogDataset('data/val', val_transform)

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4
)

# 4. Define the model: pretrained ResNet-18 with a 2-class head
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 5. Training loop
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    model.train()
    train_loss = 0.0

    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Validation loop
    model.eval()
    val_loss = 0.0
    correct = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, preds = torch.max(outputs, 1)
            correct += torch.sum(preds == labels.data)

    print(f'Epoch {epoch+1}: '
          f'Train Loss: {train_loss/len(train_loader):.4f} '
          f'Val Loss: {val_loss/len(val_loader):.4f} '
          f'Val Acc: {correct.double()/len(val_dataset):.4f}')
```
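The loop above discards the trained weights when it finishes. A common addition is to checkpoint whenever validation accuracy improves; this is a sketch with a hypothetical `save_if_best` helper and `best_model.pth` path:

```python
import torch
from torch import nn

def save_if_best(model, val_acc, best_acc, path='best_model.pth'):
    """Save model weights when validation accuracy improves; return the new best."""
    if val_acc > best_acc:
        torch.save(model.state_dict(), path)
        return val_acc
    return best_acc

# In the epoch loop:  best_acc = save_if_best(model, val_acc, best_acc)
# For inference later: model.load_state_dict(torch.load(path)); model.eval()
```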
## 8. Common Problems and Solutions
### 8.1 GPU Out of Memory
Symptoms:
- "CUDA out of memory" errors
- Frequent crashes during training

Solutions:
1. Reduce `batch_size` (e.g. from 64 to 32)
2. Use gradient accumulation:
```python
accumulation_steps = 4
optimizer.zero_grad()

for i, (inputs, labels) in enumerate(train_loader):
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = model(inputs)
    loss = criterion(outputs, labels) / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
### 8.2 Slow Data Loading
Optimizations:
1. Increase `num_workers` (a common starting point is the number of CPU cores)
2. Enable `pin_memory=True`
3. Use faster storage (e.g. an NVMe SSD)
4. Prefetch batches in the background (requires the third-party `prefetch_generator` package):
```python
from prefetch_generator import BackgroundGenerator

class DataLoaderX(DataLoader):
    def __iter__(self):
        return BackgroundGenerator(super().__iter__())
```
### 8.3 Class Imbalance
Besides the weighted sampler from section 6.3, a weighted loss function helps:
```python
class_weights = torch.tensor([1.0, 2.0]).to(device)  # assuming dogs are underrepresented
criterion = nn.CrossEntropyLoss(weight=class_weights)
```
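To see what the weighting does concretely: with `weight=[1.0, 2.0]` and `reduction='none'`, the per-sample loss for a class-1 example is exactly doubled while class-0 losses are unchanged:

```python
import torch
from torch import nn

logits = torch.tensor([[2.0, 0.0], [2.0, 0.0]])   # model favors class 0
labels = torch.tensor([0, 1])                     # second sample is class 1

plain = nn.CrossEntropyLoss(reduction='none')(logits, labels)
weighted = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0]),
                               reduction='none')(logits, labels)
# weighted[1] == 2 * plain[1]; weighted[0] == plain[0]
```

With the default `reduction='mean'`, the weighted losses are additionally divided by the sum of the applied weights rather than the batch size.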
### 8.4 Inconsistent Image Sizes
Solutions:
1. Resize everything to a fixed size
2. Use `RandomResizedCrop` for robustness
3. Pad each image to a square before resizing, so aspect ratios are preserved without distortion:
```python
from torchvision.transforms import functional as F

def pad_to_square(img):
    """Pad a PIL image with black borders so width == height."""
    w, h = img.size
    side = max(w, h)
    left, top = (side - w) // 2, (side - h) // 2
    return F.pad(img, [left, top, side - w - left, side - h - top])

transforms.Compose([
    transforms.Lambda(pad_to_square),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    ...
])
```
## 9. Summary
This article covered the full pipeline for reading a cat-vs-dog binary classification training set in PyTorch. The key points:
- Organizing the dataset into a class-per-folder directory structure
- Implementing a custom `Dataset` class, with an optional caching variant
- Data augmentation and preprocessing with `torchvision.transforms` (or Albumentations)
- Configuring the `DataLoader`: batching, workers, custom collation, and weighted sampling
- Diagnosing common problems such as GPU memory limits, slow loading, and class imbalance

By making flexible use of PyTorch's data-handling tools, we can efficiently build data pipelines suited to deep learning training. A good data-reading implementation not only speeds up training but also improves model generalization through effective augmentation.

Note: adjust parameters such as `batch_size` and `num_workers` to your hardware. The complete code is roughly 200 lines; running it step by step in a Jupyter notebook is recommended.