在PyTorch中,对MNIST手写数字分类数据进行预处理的步骤如下:
import torch
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
在这个例子中,我们首先导入了所需的库,然后使用transforms.Compose()
函数定义了一个预处理管道。这个管道包括两个步骤:
transforms.ToTensor()
:将图像数据从(28, 28)的numpy数组转换为PyTorch张量,数值范围从[0, 255]缩放到[0, 1]。transforms.Normalize((0.1307,), (0.3081,))
:对图像数据进行归一化处理,使用MNIST数据集的均值和标准差作为参数。接下来,我们分别加载了训练集和测试集,并使用预处理管道对它们进行了处理。