PyTorch DataLoader 使用

在 PyTorch 中,使用 torch.utils.data.DataLoader 类可以实现批量的数据集加载,在我们训练模型中非常常用,其功能也确实比较强度大。由于其参数比较多,我们将会对其用法进行详解。

  1. DataLoader 的基本使用
  2. DataLoader 的 collate_fn 参数
  3. DataLoader 的 sampler 参数

1. DataLoader 的基本使用

使用 DataLoader 前,我们先实现一个用于获得数据的类,假设名字为: MyDataset,其需要实现以下几个方法:

  1. __init__ 方法用于对类对象进行初始化
  2. __len__ 方法用于返回数据集中样本的数量
  3. __getitem__ 方法用于根据索引返回一条样本

接下来,将我们自己构造的 MyDataset 实例对象交给 DataLoader,由其对我们的数据集对象进行封装返回一个数据加载器。DataLoader 的 shuffle 参数可以指定是否打乱原有的数据集顺序,batch_size 参数用于指定每次加载的批次样本数量,drop_last 参数指定最后一组不够批次数量的样本是否丢弃。

shuffle 参数默认值是 False ,batch_size 参数默认值是 1,drop_last 参数默认值是 False。

import torch
from torch.utils.data import DataLoader


class MyDataset:

    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.sample_number = len(self.y)

    def __len__(self):
        return self.sample_number

    def __getitem__(self, idx):

        # 修正 idx 范围为 [0, idx]
        idx = min(max(idx, 0), self.sample_number - 1)
        # 返回一组样本
        return self.x[idx], self.y[idx]


def test():

    # 构造数据集
    x = torch.arange(21).reshape(21, 1)
    y = torch.arange(21)
    # 初始化数据集
    dataset = MyDataset(x, y)
    # 初始化数据加载器
    dataloader = DataLoader(dataset, shuffle=True, batch_size=8, drop_last=True)

    for tx, ty in dataloader:
        print(tx)


if __name__ == '__main__':
    test()

程序输出结果:

tensor([[ 7],
        [14],
        [11],
        [15],
        [ 9],
        [12],
        [ 3],
        [ 5]])
tensor([[ 8],
        [10],
        [ 1],
        [ 4],
        [13],
        [ 6],
        [ 0],
        [ 2]])

从程序可以看到,我们的样本数量为 21,每 8 个样本组成一个批次,由于设置了 drop_last 为 True,所以共打印了 2 个批次的训练数据,并且由于 shuffle 参数被设置为 True,每一个批次的样本都是被打乱的,并不是按照原来的样本数量。

注意:在上面的例子中,MyDataset 类并没有继承 torch.utils.data.Dataset 类。

2. DataLoader 的 collate_fn 参数

collate_fn 参数用于接收用于传递的一个函数。DataLoader 会从数据集中获得一个批次的数据,然后将该批次数据再传递到 collate_fn 指向的函数中进行二次处理。

import torch
from torch.utils.data import DataLoader


class MyDataset:

    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.sample_number = len(self.y)

    def __len__(self):
        return self.sample_number

    def __getitem__(self, idx):

        # 修正 idx 范围为 [0, idx]
        idx = min(max(idx, 0), self.sample_number - 1)
        # 返回一组样本
        return self.x[idx], self.y[idx]


def secondary_processing(data):
    # 在此函数中可以对数据集进行二次处理
    # 传递进行的批次数据 [(样本1, 目标值1) ... (样本2, 目标值2)]
    feature = []
    target = []

    for x, y in data:
        feature.append(x.tolist())
        target.append(y.item())

    feature = torch.tensor(feature)
    target = torch.tensor(target)

    return feature, target


def test():

    # 构造数据集
    x = torch.arange(16).reshape(16, 1)
    y = torch.arange(100, 116)
    # 初始化数据集
    dataset = MyDataset(x, y)
    # 初始化数据加载器
    dataloader = DataLoader(dataset,
                            shuffle=True,
                            batch_size=8,
                            collate_fn=secondary_processing)

    for tx, ty in dataloader:
        print(tx)


if __name__ == '__main__':
    test()

3. DataLoader 的 sampler 参数

sampler 用于设置如何从数据集中提取样本,即: 数据采样策略。如果指定该参数,则 shuffle 参数则会被忽略。在 DataLoader 中内置了几种采样器:

  1. SequentialSampler 采样策略表示按照样本的顺序进行采样
  2. BatchSampler 采样策略表示按照指定的批次索引进行采样
  3. RandomSampler 采样策略表示进行随机采样、以及是否允许有放回的采样
  4. SubsetRandomSampler 采样策略表示按照指定的集合或者索引列表进行随机采样
  5. WeightedRandomSampler 采样策略表示按照指定的概率进行随机采样

接下来,我们代码演示下上面一些采样策略的用法:

import torch
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
from torch.utils.data import BatchSampler
from torch.utils.data import RandomSampler
from torch.utils.data import SubsetRandomSampler
from torch.utils.data import WeightedRandomSampler


class MyDataset:

    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.sample_number = len(self.y)

    def __len__(self):
        return self.sample_number

    def __getitem__(self, idx):

        if isinstance(idx, int):
            idx = min(max(idx, 0), self.sample_number - 1)
            return self.x[idx], self.y[idx]

        if isinstance(idx, list):
            xs = []
            ys = []
            for i in idx:
                xs.append(self.x[i])
                ys.append(self.y[i])

            return xs, ys


# 1. SequentialSampler 的用法
def get_dataloader1(x, y):
    dataset = MyDataset(x, y)
    # SequentialSampler 需要将 dataset 作为参数
    # SequentialSampler 获得原始数据的索引
    sampler = SequentialSampler(dataset)
    # 由于 SequentialSampler 没有指定 batch_size 的参数, 需要在 DataLoader 中设置
    dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)

    return dataloader


# 2. BatchSampler 的用法
def get_dataloader2(x, y):
    dataset = MyDataset(x, y)
    # 在指定索引列表, 根据 batch_size 产生顺序产生批次数据
    # [3, 4, 5]、[7, 8, 9]、[10] 作为一个批次
    sampler = BatchSampler([3, 4, 5, 7, 8, 9, 10], batch_size=3, drop_last=False)
    # 由于 BatchSampler 指定了 batch_size, 在 DataLoader 中不需要指定
    dataloader = DataLoader(dataset, sampler=sampler)
    return dataloader


# 3. RandomSampler 的用法
def get_dataloader3(x, y):
    dataset = MyDataset(x, y)
    # RandomSampler 需要将 dataset 作为参数
    # RandomSampler 获得原始数据的索引
    sampler = RandomSampler(dataset)
    dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)

    return dataloader

def get_dataloader4(x, y):
    dataset = MyDataset(x, y)
    # 随机从 [3, 4, 5, 7, 8, 9, 10] 中产生批次
    sampler = SubsetRandomSampler([3, 4, 5, 7, 8, 9, 10])
    dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
    return dataloader

def get_dataloader5(x, y):
    dataset = MyDataset(x, y)
    # 随机从前 num_samples 个样本中,根据概率值中产生批次
    # replacement 参数表示否重复采样
    sampler = WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6, 0.4, 0.7, 3.0, 0.6],
                                    num_samples=10,
                                    replacement=False)
    dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
    return dataloader


def test():

    x = torch.arange(16).reshape(16, 1)
    y = torch.arange(100, 116)
    dataloader = get_dataloader5(x,y)
    for tx, ty in dataloader:
        print(tx, ty)


if __name__ == '__main__':
    test()

4. DataLoader Dataset 和 sampler 的关系

DataLoader 使用 sampler 产生数据索引,根据索引从 Dataset 中获得批次数据。

未经允许不得转载:一亩三分地 » PyTorch DataLoader 使用
评论 (0)

9 + 6 =