pytorch

PyTorch中怎么实现数据加载器

小亿
85
2024-05-10 15:46:56
栏目: 深度学习

在PyTorch中,可以使用torch.utils.data.DataLoader类来实现数据加载器。DataLoader可以将数据集划分成多个batch,并提供数据加载的功能。以下是一个简单的示例:

import torch
from torch.utils.data import DataLoader, Dataset

# 创建自定义的数据集类
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]

# 创建一个包含一些示例数据的数据集
data = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]
dataset = MyDataset(data)

# 创建数据加载器
batch_size = 2
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 遍历数据加载器并打印每个batch的数据
for batch in dataloader:
    print(batch)

在上面的示例中,首先创建了一个自定义的数据集类MyDataset,然后创建了一个包含示例数据的数据集dataset。接着使用DataLoader将数据集划分成batch,并设置了batch大小为2,并设置了shuffle参数为True,表示每个epoch时重新洗牌数据。最后,通过遍历数据加载器,可以打印出每个batch的数据。

0
看了该问题的人还看了