When loading data with PyTorch on Ubuntu, the following techniques can improve efficiency:
The DataLoader's num_workers parameter sets the number of parallel worker subprocesses for loading (a common starting point is 1-2x the number of CPU cores, e.g. num_workers=4), which keeps data loading from becoming a bottleneck. pin_memory=True speeds up data transfer from CPU to GPU. For image decoding, a faster JPEG library (e.g. turbojpeg) can replace the default PIL, or preprocessing can be batched through torchvision.transforms. A .contiguous() call can optimize tensor storage layout. Example:
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
# Custom dataset
class CustomDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.image_paths = [...]  # fill in the list of image paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, 0  # placeholder label of 0
# Data loading
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
dataset = CustomDataset(data_path='./data', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
# Usage inside the training loop
for batch in dataloader:
    images, labels = batch
    # non_blocking=True lets the host-to-GPU copy overlap with compute when pin_memory=True
    images, labels = images.cuda(non_blocking=True), labels.cuda(non_blocking=True)
    # model training code...
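The .contiguous() tip mentioned above can be illustrated with a minimal sketch (the tensor shapes here are arbitrary examples, not from the original text): permute() returns a non-contiguous view of the same storage, and calling .contiguous() materializes a dense copy in the new layout, which operations such as .view() require.

```python
import torch

x = torch.randn(8, 3, 224, 224)   # NCHW batch
y = x.permute(0, 2, 3, 1)         # NHWC view; shares storage, non-contiguous
print(y.is_contiguous())          # False

z = y.contiguous()                # dense copy laid out in NHWC order
print(z.is_contiguous())          # True

flat = z.view(8, -1)              # .view() works now that z is contiguous
```

If the copy is unnecessary, .reshape() can often be used instead, since it falls back to copying only when needed.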
Monitor GPU utilization with nvidia-smi and CPU/memory usage with htop, and adjust num_workers promptly to avoid resource contention.
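Beyond watching nvidia-smi and htop, a quick way to pick num_workers is to time a full pass over the dataset for a few candidate values. The snippet below is a hypothetical micro-benchmark (the synthetic dataset size and batch size are assumptions, not from the original text):

```python
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for a real dataset: 2048 fake 3x32x32 images with zero labels
dataset = TensorDataset(
    torch.randn(2048, 3, 32, 32),
    torch.zeros(2048, dtype=torch.long),
)

def time_one_epoch(num_workers):
    """Return the wall-clock time of one full iteration over the dataset."""
    loader = DataLoader(dataset, batch_size=64, num_workers=num_workers)
    start = time.perf_counter()
    for images, labels in loader:
        pass  # iterate only; no model work
    return time.perf_counter() - start

for workers in (0, 2, 4):
    print(f"num_workers={workers}: {time_one_epoch(workers):.3f}s")
```

With a real dataset that does actual disk I/O and decoding, the differences between worker counts are usually much larger than with this in-memory example, so run the comparison on your own data.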