在PyTorch中自定义数据集需要继承torch.utils.data.Dataset类,并实现以下方法:
__init__(self, ...):初始化方法,可以在这里加载数据或设置数据路径等。__len__(self):返回数据集的大小。__getitem__(self, index):根据索引返回数据样本。以下是一个例子,假设我们有一个包含图像和标签的数据集:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = {
'image': self.data[index],
'label': self.labels[index]
}
return sample
# 使用自定义数据集
data = [...] # 图像数据
labels = [...] # 图像标签
custom_dataset = CustomDataset(data, labels)
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=64, shuffle=True)
在上面的例子中,CustomDataset类继承了torch.utils.data.Dataset,并实现了__init__、__len__和__getitem__方法。然后我们可以通过创建一个DataLoader对象来加载自定义数据集,以便于后续的训练或测试。