Ignite 是一个可以帮助我们在 PyTorch 中训练和评估神经网络的高级库。简单来讲,使用该训练库可以让我们的训练代码更加简洁,灵活。工具的安装命令如下:
pip install pytorch-ignite
Ignite 中主要有以下 4 个重要概念:
- Engine
- Events And Handler
- state
- Metrics
1. Engine
我们在编写训练函数时,经常需要编写一个循环嵌套,外层控制 epoch,内层循环控制 iteration。Engine 就是对循环的抽象。例如:
from ignite.engine import Engine from torch.utils.data import DataLoader from torch.utils.data import TensorDataset import torch def train_step(engine, batch_data): # 参数是由 Engine 传递进来 # engine 为调用 train_step 函数的 Engine 对象 # batch_data 为 DataLoader 加载的 batch 数据 print(engine) print(batch_data) print('-' * 55) y_true = [0, 0] y_pred = [1, 0] loss = torch.randint(0, 10, size=(1,)) # 单次迭代之后,返回结果,可以以任何的格式返回任意多的数据 return {'y_true': y_true, 'y_pred': y_pred, 'loss': loss} if __name__ == '__main__': # 初始化 Engine 对象,并设置循环执行的函数 trainer = Engine(train_step) # run 函数需要传递两个参数 # 1. 数据加载器 DataLoader 对象 # 2. 训练轮数 epoch 值 # 构建数据加载器 train_data = TensorDataset(torch.randn(10, 1), torch.randint(0, 2, size=(10,))) train_data = DataLoader(train_data, batch_size=2, shuffle=True) # 执行 max_epochs 数量的 train_step # 注意: 并不是执行 max_epochs 次 train_step 函数 # 函数 train_step 的执行次数 = max_epochs * 每一轮的 iteration trainer.run(train_data, max_epochs=2)
2. Events And Handler
Ignite 提供了非常好用的事件机制。所谓事件机制就是在恰当的时间执行某个函数。例如:我们在训练过程中,其实就会存在 6 个比较重要的时间点,如下所示:
- 训练开始前
- 训练结束后
- 每一个 epoch 开始前
- 每一个 epoch 结束后
- 每一个 iteration 开始前
- 每一个 iteration 结束后
如果我们想在某个时间点执行某个函数,比如在每一个 iteration 结束后,打印下损失值。就可以先编写一个对应的功能函数,并将其按照事件类型注册到 engine 内部。注册事件有两种方式:
- 装饰器方式
- 函数调用方式
下面演示了装饰器方式:
from ignite.engine import Engine from ignite.engine import Events def train_step(engine, batch_data): return 0.1314 # 初始化训练器 trainer = Engine(train_step) # 表示将 train_start 函数注册到 trainer # 并设置 Events.STARTED 时刻执行该函数 # engine 为调用该事件的 Engine 对象 @trainer.on(Events.STARTED) def train_start(engine): print('训练开始前') @trainer.on(Events.COMPLETED) def train_end(engine): print('训练结束后') def event_filter(engine, epoch_idx): # epoch_idx 表示 epoch 数量 # 我们可以在指定 epoch 时做某些操作 return epoch_idx in [1, 3, 5, 7] # 在满足 event_filter 条件的 epoch 开始执行下面的操作 # 所以 event_filter 就是一个自定义的触发条件 @trainer.on(Events.EPOCH_STARTED(event_filter=event_filter)) def epoch_start(engine): print('EPOCH 开始前 event_filter') # 第 once epoch 结束后执行一次,后续不再执行 @trainer.on(Events.EPOCH_COMPLETED(once=3)) def epoch_start(engine): print('EPOCH 结束后 once=3') @trainer.on(Events.ITERATION_STARTED) def iteration_start(engine): print('ITERATION 开始前') # 每 every 个 epoch 后执行该事件函数 @trainer.on(Events.ITERATION_COMPLETED(every=2)) def iteration_start(engine): print('ITERATION 结束后 every=2') if __name__ == '__main__': # 我们这里为了对事件机制举例就不传入实际的数据 # 此时,需要设置 epoch_length 值 # 否则 Engine 不知道每个 epoch 循环多少次 trainer.run(max_epochs=9, epoch_length=4)
- 每个输出:”ITERATION 结束后” 之前会执行 2 次 “ITERATION 开始前”,表示每迭代 2 次执行一次该事件
- 在第 3 个 epoch 结束输出了 “EPOCH 结束后 once=3” ,此后再也没有输出,表示该事件只会触发一次
- 在第 [1, 3, 5 ,7] epoch 输出 “EPOCH 开始前 event_filter”,表示自定义的事件在合适的时间发挥了作用
训练开始前 EPOCH 开始前 event_filter ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 EPOCH 开始前 event_filter ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 EPOCH 结束后 once=3 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 EPOCH 开始前 event_filter ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 EPOCH 开始前 event_filter ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 ITERATION 开始前 ITERATION 开始前 ITERATION 结束后 every=2 训练结束后
下面演示函数的方式:
if __name__ == '__main__': # 初始化训练器 trainer = Engine(train_step) # 第一个参数为事件名称 # 第二个参数为事件函数 # 第三个参数为传递给事件函数的参数 trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine, data: print(data), [10, 20, 30]) # 我们这里为了对事件机制举例就不传入实际的数据 # 此时,需要设置 epoch_length 值 # 否则 Engine 不知道每个 epoch 循环多少次 trainer.run(max_epochs=9, epoch_length=4)
3. State
在 Engine 中存储了一些有用的信息,在我们训练过程中是需要的,下面为示例代码:
from ignite.engine import Engine from ignite.engine import Events def train_step(engine, batch_data): return {'y_pred': [1, 0], 'y_true': [0, 1], 'loss': 0.1314} def iteration_start(engine): print('最大训练轮数:', trainer.state.max_epochs) print('当前训练轮数:', trainer.state.epoch) print('当前迭代次数:', trainer.state.iteration) print('轮次迭代次数:', trainer.state.epoch_length) print('当前迭代输出:', trainer.state.output) print('-' * 64) if __name__ == '__main__': trainer = Engine(train_step) trainer.add_event_handler(Events.ITERATION_COMPLETED, iteration_start) trainer.run(max_epochs=2, epoch_length=2)
程序输出结果:
最大训练轮数: 2 当前训练轮数: 1 当前迭代次数: 1 轮次迭代次数: 2 当前迭代输出: {'y_pred': [1, 0], 'y_true': [0, 1], 'loss': 0.1314} ---------------------------------------------------------------- 最大训练轮数: 2 当前训练轮数: 1 当前迭代次数: 2 轮次迭代次数: 2 当前迭代输出: {'y_pred': [1, 0], 'y_true': [0, 1], 'loss': 0.1314} ---------------------------------------------------------------- 最大训练轮数: 2 当前训练轮数: 2 当前迭代次数: 3 轮次迭代次数: 2 当前迭代输出: {'y_pred': [1, 0], 'y_true': [0, 1], 'loss': 0.1314} ---------------------------------------------------------------- 最大训练轮数: 2 当前训练轮数: 2 当前迭代次数: 4 轮次迭代次数: 2 当前迭代输出: {'y_pred': [1, 0], 'y_true': [0, 1], 'loss': 0.1314} ----------------------------------------------------------------
4. Metrics
我们可以在 trainer 的 EPOCH_COMPLETED 事件函数中,来 run evaluator 来获得评估结果,并打印。
from ignite.metrics import Loss from ignite.metrics import Accuracy from ignite.engine import Engine from ignite.engine import Events import torch import torch.nn as nn def evaluation_step(engine, batch_data): # model.eval() # with torch.no_grad(): # pass y_pred = torch.tensor([[0.2, 0.8], [0.6, 0.4]]) y_true = torch.tensor([1, 1]) # 注意: 需要返回 tensor 类型 return y_pred, y_true def print_metric_result(engine): print('y_pred:', engine.state.output) print('y_loss:', engine.state.metrics['loss']) print('y_accu:', engine.state.metrics['accu']) if __name__ == '__main__': evaluator = Engine(evaluation_step) # 构建评估指标 loss = Loss(nn.CrossEntropyLoss()) accu = Accuracy() # 将评估指标添加到评估器中 # 第二个参数获取评估结果时,通过 engine.state.metric['accuracy'] 来获取结果 # 注意: evaluation_step 负责返回 Accuracy 需要的数据 loss.attach(evaluator, 'loss') accu.attach(evaluator, 'accu') evaluator.add_event_handler(Events.EPOCH_COMPLETED, print_metric_result) evaluator.run(epoch_length=2)
程序输出结果:
y_pred: (tensor([[0.2000, 0.8000], [0.6000, 0.4000]]), tensor([1, 1])) y_loss: 0.6178134083747864 y_accu: 0.5