Pytorch_数据集的创建和加载

PyTorch通过torch.utils.data对一般常用的数据加载进行了封装,可以很容易地实现多线程数据预读和批量加载。

可以通过dataset定义数据集,并使用Datalorder载入和遍历数据集

Dataset

Dataset是一个抽象类,为了能够方便的读取,需要将要使用的数据包装为Dataset类。

自定义的Dataset需要继承它并且实现两个成员方法:

  1. __getitem__() 该方法定义用索引(0len(self))获取一条数据或一个样本
  2. __len__() 该方法返回数据集的总长度

下面使用kaggle上的一个竞赛bluebook for bulldozers自定义一个数据集,用里面的数据字典来做说明(因为条数少)

from torch.utils.data import Dataset
import pandas as pd

#定义一个数据集
class BulldozerDataset(Dataset):
    
    # 实现初始化方法,在初始化的时候将数据读载入
    def __init__(self, csv_file):
        self.df=pd.read_csv(csv_file)
    
    # 返回df的长度
    def __len__(self):
        return len(self.df)

    # 根据 idx 返回一行数据
    def __getitem__(self, idx):
        return self.df.iloc[idx].SalePrice

至此,我们的数据集已经定义完成了,我们可以实例话一个对象访问他

ds_demo = BulldozerDataset('median_benchhmark.csv')
print(len(ds_demo)) # 11573
print(ds_demo[0]) # 24000.0

Dataloader

DataLoader为我们提供了对Dataset的读取操作,常用参数有:batch_size(每个batch的大小)、 shuffle(是否进行shuffle操作)、 num_workers(加载数据的时候使用几个子进程)。下面做一个简单的操作

dl = torch.utils.data.DataLoader(ds_demo, batch_size=10, shuffle=True, num_workers=0)

DataLoader返回的是一个可迭代对象,我们可以使用迭代器分次获取数据

idata=iter(dl)
print(next(idata))
# Output:
# tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.])

常见的用法是使用for循环对其进行遍历

for i, data in enumerate(dl):
    print(i,data)
    break
Licensed under CC BY-NC-SA 4.0
comments powered by Disqus