PyTorch Ignite Concepts

Ignite 是一个可以帮助我们在 PyTorch 中训练和评估神经网络的高级库。简单来讲,使用该训练库可以让我们的训练代码更加简洁,灵活。工具的安装命令如下:

pip install pytorch-ignite

Ignite 中主要有以下 4 个重要概念:

  1. Engine
  2. Events And Handler
  3. state
  4. Metrics

1. Engine

我们在编写训练函数时,经常需要编写一个循环嵌套,外层控制 epoch,内层循环控制 iteration。Engine 就是对循环的抽象。例如:

from ignite.engine import Engine
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch


def train_step(engine, batch_data):
    # 参数是由 Engine 传递进来
    # engine 为调用 train_step 函数的 Engine 对象
    # batch_data 为 DataLoader 加载的 batch 数据
    print(engine)
    print(batch_data)
    print('-' * 55)

    y_true = [0, 0]
    y_pred = [1, 0]
    loss = torch.randint(0, 10, size=(1,))

    # 单次迭代之后,返回结果,可以以任何的格式返回任意多的数据
    return {'y_true': y_true, 'y_pred': y_pred, 'loss': loss}


if __name__ == '__main__':

    # 初始化 Engine 对象,并设置循环执行的函数
    trainer = Engine(train_step)
    # run 函数需要传递两个参数
    # 1. 数据加载器 DataLoader 对象
    # 2. 训练轮数 epoch 值

    # 构建数据加载器
    train_data = TensorDataset(torch.randn(10, 1), torch.randint(0, 2, size=(10,)))
    train_data = DataLoader(train_data, batch_size=2, shuffle=True)
    # 执行 max_epochs 数量的 train_step
    # 注意: 并不是执行 max_epochs 次 train_step 函数
    # 函数 train_step 的执行次数 = max_epochs * 每一轮的 iteration
    trainer.run(train_data, max_epochs=2)

2. Events And Handler

Ignite 提供了非常好用的事件机制。所谓事件机制就是在恰当的时间执行某个函数。例如:我们在训练过程中,其实就会存在 6 个比较重要的时间点,如下所示:

  1. 训练开始前
  2. 训练结束后
  3. 每一个 epoch 开始前
  4. 每一个 epoch 结束后
  5. 每一个 iteration 开始前
  6. 每一个 iteration 结束后

如果我们想在某个时间点执行某个函数,比如在每一个 iteration 结束后,打印下损失值。就可以先编写一个对应的功能函数,并将其按照事件类型注册到 engine 内部。注册事件有两种方式:

  1. 装饰器方式
  2. 函数调用方式

下面演示了装饰器方式:

from ignite.engine import Engine
from ignite.engine import Events


def train_step(engine, batch_data):
    return 0.1314


# 初始化训练器
trainer = Engine(train_step)

# 表示将 train_start 函数注册到 trainer
# 并设置 Events.STARTED 时刻执行该函数
# engine 为调用该事件的 Engine 对象
@trainer.on(Events.STARTED)
def train_start(engine):
    print('训练开始前')

@trainer.on(Events.COMPLETED)
def train_end(engine):
    print('训练结束后')


def event_filter(engine, epoch_idx):
    # epoch_idx 表示 epoch 数量
    # 我们可以在指定 epoch 时做某些操作

    return epoch_idx in [1, 3, 5, 7]

# 在满足 event_filter 条件的 epoch 开始执行下面的操作
# 所以 event_filter 就是一个自定义的触发条件
@trainer.on(Events.EPOCH_STARTED(event_filter=event_filter))
def epoch_start(engine):
    print('EPOCH 开始前 event_filter')


# 第 once epoch 结束后执行一次,后续不再执行
@trainer.on(Events.EPOCH_COMPLETED(once=3))
def epoch_start(engine):
    print('EPOCH 结束后 once=3')


@trainer.on(Events.ITERATION_STARTED)
def iteration_start(engine):
    print('ITERATION 开始前')


# 每 every 个 epoch 后执行该事件函数
@trainer.on(Events.ITERATION_COMPLETED(every=2))
def iteration_start(engine):
    print('ITERATION 结束后 every=2')


if __name__ == '__main__':

    # 我们这里为了对事件机制举例就不传入实际的数据
    # 此时,需要设置 epoch_length 值
    # 否则 Engine 不知道每个 epoch 循环多少次
    trainer.run(max_epochs=9, epoch_length=4)
  1. 每个输出:”ITERATION 结束后” 之前会执行 2 次 “ITERATION 开始前”,表示每迭代 2 次执行一次该事件
  2. 在第 3 个 epoch 结束输出了 “EPOCH 结束后 once=3” ,此后再也没有输出,表示该事件只会触发一次
  3. 在第 [1, 3, 5 ,7] epoch 输出 “EPOCH 开始前 event_filter”,表示自定义的事件在合适的时间发挥了作用
训练开始前
EPOCH 开始前 event_filter
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
EPOCH 开始前 event_filter
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
EPOCH 结束后 once=3
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
EPOCH 开始前 event_filter
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
EPOCH 开始前 event_filter
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
ITERATION 开始前
ITERATION 开始前
ITERATION 结束后 every=2
训练结束后

下面演示函数的方式:

if __name__ == '__main__':

    # 初始化训练器
    trainer = Engine(train_step)
    # 第一个参数为事件名称
    # 第二个参数为事件函数
    # 第三个参数为传递给事件函数的参数
    trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine, data: print(data), [10, 20, 30])


    # 我们这里为了对事件机制举例就不传入实际的数据
    # 此时,需要设置 epoch_length 值
    # 否则 Engine 不知道每个 epoch 循环多少次
    trainer.run(max_epochs=9, epoch_length=4)

3. State

在 Engine 中存储了一些有用的信息,在我们训练过程中是需要的,下面为示例代码:

from ignite.engine import Engine
from ignite.engine import Events


def train_step(engine, batch_data):
    return {'y_pred': [1, 0], 'y_true': [0, 1], 'loss': 0.1314}


def iteration_start(engine):
    print('最大训练轮数:', trainer.state.max_epochs)
    print('当前训练轮数:', trainer.state.epoch)
    print('当前迭代次数:', trainer.state.iteration)
    print('轮次迭代次数:', trainer.state.epoch_length)
    print('当前迭代输出:', trainer.state.output)
    print('-' * 64)


if __name__ == '__main__':

    trainer = Engine(train_step)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, iteration_start)
    trainer.run(max_epochs=2, epoch_length=2)

程序输出结果:

最大训练轮数: 2
当前训练轮数: 1
当前迭代次数: 1
轮次迭代次数: 2
当前迭代输出: {'y_pred': [1, 0], 'y_true': [0, 1], 'loss': 0.1314}
----------------------------------------------------------------
最大训练轮数: 2
当前训练轮数: 1
当前迭代次数: 2
轮次迭代次数: 2
当前迭代输出: {'y_pred': [1, 0], 'y_true': [0, 1], 'loss': 0.1314}
----------------------------------------------------------------
最大训练轮数: 2
当前训练轮数: 2
当前迭代次数: 3
轮次迭代次数: 2
当前迭代输出: {'y_pred': [1, 0], 'y_true': [0, 1], 'loss': 0.1314}
----------------------------------------------------------------
最大训练轮数: 2
当前训练轮数: 2
当前迭代次数: 4
轮次迭代次数: 2
当前迭代输出: {'y_pred': [1, 0], 'y_true': [0, 1], 'loss': 0.1314}
----------------------------------------------------------------

4. Metrics

我们可以在 trainer 的 EPOCH_COMPLETED 事件函数中,来 run evaluator 来获得评估结果,并打印。

from ignite.metrics import Loss
from ignite.metrics import Accuracy
from ignite.engine import Engine
from ignite.engine import Events
import torch
import torch.nn as nn


def evaluation_step(engine, batch_data):
    # model.eval()
    # with torch.no_grad():
    #     pass

    y_pred = torch.tensor([[0.2, 0.8], [0.6, 0.4]])
    y_true = torch.tensor([1, 1])

    # 注意: 需要返回 tensor 类型
    return y_pred, y_true


def print_metric_result(engine):
    print('y_pred:', engine.state.output)
    print('y_loss:', engine.state.metrics['loss'])
    print('y_accu:', engine.state.metrics['accu'])


if __name__ == '__main__':

    evaluator = Engine(evaluation_step)

    # 构建评估指标
    loss = Loss(nn.CrossEntropyLoss())
    accu = Accuracy()
    # 将评估指标添加到评估器中
    # 第二个参数获取评估结果时,通过 engine.state.metric['accuracy'] 来获取结果
    # 注意: evaluation_step 负责返回 Accuracy 需要的数据
    loss.attach(evaluator, 'loss')
    accu.attach(evaluator, 'accu')

    evaluator.add_event_handler(Events.EPOCH_COMPLETED, print_metric_result)
    evaluator.run(epoch_length=2)

程序输出结果:

y_pred: (tensor([[0.2000, 0.8000],
        [0.6000, 0.4000]]), tensor([1, 1]))
y_loss: 0.6178134083747864
y_accu: 0.5
未经允许不得转载:一亩三分地 » PyTorch Ignite Concepts