Teacher Forcing

Teacher Forcing​ 是一种在训练序列生成模型(如循环神经网络 RNN、长短期记忆网络 LSTM、Transformer 等)时使用的技术。

1. 问题场景

Teacher Forcing 是一种用于训练序列生成模型的方法,它的核心思想是在训练过程中,使用真实的标签(ground truth)作为模型下一步的输入,而不是使用模型自身生成的输出。例如:

不使用 Teacher Forcing
模型生成第一个词 “I”,然后将 “I” 作为输入生成下一个词。如果模型生成的第一个词是 “You”,那么接下来的输入就是 “You”,这可能导致后续生成的序列完全偏离真实标签。

使用 Teacher Forcing
模型生成第一个词 “I”,但无论模型生成的是什么,下一个时间步的输入始终是真实的标签 “love”。这样,模型在训练过程中始终能够看到正确的上下文,从而更容易学习到正确的序列生成模式。

由于模型的每一步输入都来自真实数据,而非它自己的预测值,因此训练过程更加稳定,可以显著减少梯度消失问题,收敛速度更快。

如果使用模型自身的预测作为下一步输入,早期预测的错误可能会被不断放大,影响整个序列的生成质量。而 Teacher Forcing 避免了这种错误的级联效应。

2. 曝光偏差

尽管 Teacher Forcing 在训练时效果显著,但在实际推理(inference)阶段,模型无法再依赖真实数据,而是必须依赖自己的预测结果。这会导致 Exposure Bias(曝光偏差) 问题:

  • 训练与推理的不匹配(train-test discrepancy):训练时模型始终使用正确的输入,但测试时必须依赖自己的预测结果,可能导致性能下降。
  • 泛化能力受限:模型未学会如何从自身错误中恢复,因而在测试时容易失控,生成低质量的输出。

为了解决 Teacher Forcing 带来的问题,可以使用 Scheduled Sampling 或强化学习方法来缓解其影响。它的 核心思想 是在训练过程中,逐步减少使用真实目标值作为输入的概率,让模型逐步适应自身预测。

具体做法是引入一个 概率 p

  • 以概率 p 使用真实的作为输入(Teacher Forcing)
  • 以概率 1-p 使用模型的预测结果作为输入
  • 训练过程中逐步降低 p,让模型适应自身预测。

在实际应用中,是否使用 Teacher Forcing 需要权衡训练稳定性与泛化能力。

未经允许不得转载:一亩三分地 » Teacher Forcing
评论 (0)

4 + 8 =