pytorch

pytorch读取csv数据集的方法有哪些

小亿
176
2024-03-30 13:31:04
栏目: 深度学习

在PyTorch中读取CSV数据集通常有以下几种方法:

  1. 使用Pandas库读取CSV文件,并将其转换为PyTorch张量:
import pandas as pd
import torch

# 读取CSV文件
data = pd.read_csv('data.csv')

# 将数据转换为PyTorch张量
tensor_data = torch.tensor(data.values)
  1. 使用PyTorch的Dataset和DataLoader类来读取CSV文件:
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, csv_file):
        self.data = pd.read_csv(csv_file)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return torch.tensor(self.data.iloc[idx].values)

dataset = MyDataset('data.csv')
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
  1. 使用自定义的数据加载器来读取CSV文件:
import torch

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, csv_file):
        data = pd.read_csv(csv_file)
        self.X = torch.tensor(data.iloc[:, :-1].values, dtype=torch.float32)
        self.y = torch.tensor(data.iloc[:, -1].values, dtype=torch.long)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

dataset = CustomDataset('data.csv')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

这些是一些常用的方法,你可以根据自己的需求选择适合的方法来读取CSV数据集。

0
看了该问题的人还看了