PyTorch通过torch.utils.data
对一般常用的数据加载进行了封装,可以很容易地实现多线程数据预读和批量加载。
可以通过dataset定义数据集,并使用Datalorder载入和遍历数据集
Dataset
Dataset是一个抽象类,为了能够方便的读取,需要将要使用的数据包装为Dataset类。
自定义的Dataset需要继承它并且实现两个成员方法:
__getitem__()
该方法定义用索引(0
到len(self)
)获取一条数据或一个样本__len__()
该方法返回数据集的总长度
下面使用kaggle上的一个竞赛bluebook for bulldozers自定义一个数据集,用里面的数据字典来做说明(因为条数少)
from torch.utils.data import Dataset
import pandas as pd
#定义一个数据集
class BulldozerDataset(Dataset):
# 实现初始化方法,在初始化的时候将数据读载入
def __init__(self, csv_file):
self.df=pd.read_csv(csv_file)
# 返回df的长度
def __len__(self):
return len(self.df)
# 根据 idx 返回一行数据
def __getitem__(self, idx):
return self.df.iloc[idx].SalePrice
至此,我们的数据集已经定义完成了,我们可以实例话一个对象访问他
ds_demo = BulldozerDataset('median_benchhmark.csv')
print(len(ds_demo)) # 11573
print(ds_demo[0]) # 24000.0
Dataloader
DataLoader为我们提供了对Dataset的读取操作,常用参数有:batch_size
(每个batch的大小)、 shuffle
(是否进行shuffle操作)、 num_workers
(加载数据的时候使用几个子进程)。下面做一个简单的操作
dl = torch.utils.data.DataLoader(ds_demo, batch_size=10, shuffle=True, num_workers=0)
DataLoader返回的是一个可迭代对象,我们可以使用迭代器分次获取数据
idata=iter(dl)
print(next(idata))
# Output:
# tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.])
常见的用法是使用for循环对其进行遍历
for i, data in enumerate(dl):
print(i,data)
break