在 PyTorch 中,使用 torch.utils.data.DataLoader 类可以实现批量的数据集加载,在我们训练模型中非常常用,其功能也确实比较强度大。由于其参数比较多,我们将会对其用法进行详解。
- DataLoader 的基本使用
- DataLoader 的 collate_fn 参数
- DataLoader 的 sampler 参数
1. DataLoader 的基本使用
使用 DataLoader 前,我们先实现一个用于获得数据的类,假设名字为: MyDataset,其需要实现以下几个方法:
- __init__ 方法用于对类对象进行初始化
- __len__ 方法用于返回数据集中样本的数量
- __getitem__ 方法用于根据索引返回一条样本
接下来,将我们自己构造的 MyDataset 实例对象交给 DataLoader,由其对我们的数据集对象进行封装返回一个数据加载器。DataLoader 的 shuffle 参数可以指定是否打乱原有的数据集顺序,batch_size 参数用于指定每次加载的批次样本数量,drop_last 参数指定最后一组不够批次数量的样本是否丢弃。
shuffle 参数默认值是 False ,batch_size 参数默认值是 1,drop_last 参数默认值是 False。
import torch from torch.utils.data import DataLoader class MyDataset: def __init__(self, x, y): self.x = x self.y = y self.sample_number = len(self.y) def __len__(self): return self.sample_number def __getitem__(self, idx): # 修正 idx 范围为 [0, idx] idx = min(max(idx, 0), self.sample_number - 1) # 返回一组样本 return self.x[idx], self.y[idx] def test(): # 构造数据集 x = torch.arange(21).reshape(21, 1) y = torch.arange(21) # 初始化数据集 dataset = MyDataset(x, y) # 初始化数据加载器 dataloader = DataLoader(dataset, shuffle=True, batch_size=8, drop_last=True) for tx, ty in dataloader: print(tx) if __name__ == '__main__': test()
程序输出结果:
tensor([[ 7], [14], [11], [15], [ 9], [12], [ 3], [ 5]]) tensor([[ 8], [10], [ 1], [ 4], [13], [ 6], [ 0], [ 2]])
从程序可以看到,我们的样本数量为 21,每 8 个样本组成一个批次,由于设置了 drop_last 为 True,所以共打印了 2 个批次的训练数据,并且由于 shuffle 参数被设置为 True,每一个批次的样本都是被打乱的,并不是按照原来的样本数量。
注意:在上面的例子中,MyDataset 类并没有继承 torch.utils.data.Dataset 类。
2. DataLoader 的 collate_fn 参数
collate_fn 参数用于接收用于传递的一个函数。DataLoader 会从数据集中获得一个批次的数据,然后将该批次数据再传递到 collate_fn 指向的函数中进行二次处理。
import torch from torch.utils.data import DataLoader class MyDataset: def __init__(self, x, y): self.x = x self.y = y self.sample_number = len(self.y) def __len__(self): return self.sample_number def __getitem__(self, idx): # 修正 idx 范围为 [0, idx] idx = min(max(idx, 0), self.sample_number - 1) # 返回一组样本 return self.x[idx], self.y[idx] def secondary_processing(data): # 在此函数中可以对数据集进行二次处理 # 传递进行的批次数据 [(样本1, 目标值1) ... (样本2, 目标值2)] feature = [] target = [] for x, y in data: feature.append(x.tolist()) target.append(y.item()) feature = torch.tensor(feature) target = torch.tensor(target) return feature, target def test(): # 构造数据集 x = torch.arange(16).reshape(16, 1) y = torch.arange(100, 116) # 初始化数据集 dataset = MyDataset(x, y) # 初始化数据加载器 dataloader = DataLoader(dataset, shuffle=True, batch_size=8, collate_fn=secondary_processing) for tx, ty in dataloader: print(tx) if __name__ == '__main__': test()
3. DataLoader 的 sampler 参数
sampler 用于设置如何从数据集中提取样本,即: 数据采样策略。如果指定该参数,则 shuffle 参数则会被忽略。在 DataLoader 中内置了几种采样器:
- SequentialSampler 采样策略表示按照样本的顺序进行采样
- BatchSampler 采样策略表示按照指定的批次索引进行采样
- RandomSampler 采样策略表示进行随机采样、以及是否允许有放回的采样
- SubsetRandomSampler 采样策略表示按照指定的集合或者索引列表进行随机采样
- WeightedRandomSampler 采样策略表示按照指定的概率进行随机采样
接下来,我们代码演示下上面一些采样策略的用法:
import torch from torch.utils.data import DataLoader from torch.utils.data import SequentialSampler from torch.utils.data import BatchSampler from torch.utils.data import RandomSampler from torch.utils.data import SubsetRandomSampler from torch.utils.data import WeightedRandomSampler class MyDataset: def __init__(self, x, y): self.x = x self.y = y self.sample_number = len(self.y) def __len__(self): return self.sample_number def __getitem__(self, idx): if isinstance(idx, int): idx = min(max(idx, 0), self.sample_number - 1) return self.x[idx], self.y[idx] if isinstance(idx, list): xs = [] ys = [] for i in idx: xs.append(self.x[i]) ys.append(self.y[i]) return xs, ys # 1. SequentialSampler 的用法 def get_dataloader1(x, y): dataset = MyDataset(x, y) # SequentialSampler 需要将 dataset 作为参数 # SequentialSampler 获得原始数据的索引 sampler = SequentialSampler(dataset) # 由于 SequentialSampler 没有指定 batch_size 的参数, 需要在 DataLoader 中设置 dataloader = DataLoader(dataset, batch_size=8, sampler=sampler) return dataloader # 2. BatchSampler 的用法 def get_dataloader2(x, y): dataset = MyDataset(x, y) # 在指定索引列表, 根据 batch_size 产生顺序产生批次数据 # [3, 4, 5]、[7, 8, 9]、[10] 作为一个批次 sampler = BatchSampler([3, 4, 5, 7, 8, 9, 10], batch_size=3, drop_last=False) # 由于 BatchSampler 指定了 batch_size, 在 DataLoader 中不需要指定 dataloader = DataLoader(dataset, sampler=sampler) return dataloader # 3. RandomSampler 的用法 def get_dataloader3(x, y): dataset = MyDataset(x, y) # RandomSampler 需要将 dataset 作为参数 # RandomSampler 获得原始数据的索引 sampler = RandomSampler(dataset) dataloader = DataLoader(dataset, batch_size=8, sampler=sampler) return dataloader def get_dataloader4(x, y): dataset = MyDataset(x, y) # 随机从 [3, 4, 5, 7, 8, 9, 10] 中产生批次 sampler = SubsetRandomSampler([3, 4, 5, 7, 8, 9, 10]) dataloader = DataLoader(dataset, batch_size=4, sampler=sampler) return dataloader def get_dataloader5(x, y): dataset = MyDataset(x, y) # 随机从前 num_samples 个样本中,根据概率值中产生批次 # replacement 参数表示否重复采样 sampler = WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6, 0.4, 0.7, 3.0, 0.6], num_samples=10, replacement=False) dataloader = DataLoader(dataset, batch_size=4, sampler=sampler) return dataloader def test(): x = torch.arange(16).reshape(16, 1) y = torch.arange(100, 116) dataloader = get_dataloader5(x,y) for tx, ty in dataloader: print(tx, ty) if __name__ == '__main__': test()
4. DataLoader Dataset 和 sampler 的关系
DataLoader 使用 sampler 产生数据索引,根据索引从 Dataset 中获得批次数据。