我们知道 BERT 模型是应用 MLM 预训练任务得来的,而 ELECTRA(Efficiently Learning an Encoder that Classifies Token Replacements Accurately)模型则是使用 RTD(Replaced Token Detection)预训练任务得来的。ELECTRA 在使用鞥更少的参数和数据的情况下,也能够带来非常好的效果。
Paper 链接:https://openreview.net/pdf?id=r1xMH1BtvB
1. ELECTRA 计算过程
我截取了原 Paper 中的图来说明 ELECTRA 模型的计算过程,如下图所示:

BERT 的 MLM 预训练任务会对随机选择 15% 的 Token, 然后在这 15% 中的 80% 随机替换 [MASK],10% 保持原 Token, 10% 随机替换成其他 Token。
ELECTRA 则有些不同,它会选择部分的 Token 替换为 [MASK],但是替换成什么会由 Genertor 模型进行预测,用预测得到的 Token 替换 [MASK],即:Generator 计算之后会得到一个新的序列,这个序列其实是 Generator 根据带 [MASK] 的输入计算得出的 last_hidden_state 最后一层的隐藏状态。
接下来,将 last_hidden_state 输入到 Discriminator 中,Discriminator 则判断输入中的每一个 Token 是否被修改过,所以对应每个输入 Token 的值都是一个 logit。例如我们输入:”[CLS] 我 是 谁 ? [SEP]”,则输出的结果为:
[ 0.1326, 0.0408, -0.0277, 0.0670, -0.0679, -0.0209]
我们接下来可以通过 BCEWithLogitsLoss 来判断每个 Token 的损失,并进行反向传播,学习参数了。在上面的过程中,Generator 的职责有点像生成对抗网络(GAN)中的生成器,用来产生一个序列。Discriminator 则对应了判别器,判断生成器是否对输入进行了修改。
我在想,通过 RTD 模型想要学到什么能力呢?
- 当判别器预测错误时,判别器认为生成器对词的表示不够好,导致判别器无法区分,编码器必须重新调整对词的表示;
- 当判别器预测正确时,判别器认为生成器对词的表示还不错,不需要剧烈的调整表示;
也就是说,ELECTRA 希望啥时候判别器能够更好的区分词的表示了,那么生成器的词向量也就变得更加好了。
2. Transformers 源码理解
transformers 库可以通过下面的命令进行安装:
pip install transformers
在 transformers 库中对 ELECTRA 进行了实现。我们以 ElectraForPreTraining 实现来理解下 ELECTRA 模型的计算过程:
class ElectraForPreTraining(ElectraPreTrainedModel): def __init__(self, config): super().__init__(config) # 生成器 self.electra = ElectraModel(config) # 判别器 self.discriminator_predictions = ElectraDiscriminatorPredictions(config) # Initialize weights and apply final processing self.post_init()
从实现代码可以看到,ELECTRA 模型包含了 2 个部分,ElectraModel 部分再预训练时,则对包含 [MASK] 的输入进行编码,该部分的实现代码如下:
class ElectraModel(ElectraPreTrainedModel): def __init__(self, config): super().__init__(config) # 输入进行词嵌入 self.embeddings = ElectraEmbeddings(config) # 如果词嵌入维度和隐藏层输出维度不同,则线性变换到隐藏层维度 if config.embedding_size != config.hidden_size: self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size) # 默认包含了 12 个 Transformer 隐藏层 self.encoder = ElectraEncoder(config) self.config = config # Initialize weights and apply final processing self.post_init()
接下来,看下 ElectraModel 类前向计算函数:
def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None # ... 省略部分参数 ) # ... 省略部分代码 # 先进行词嵌入 hidden_states = self.embeddings( input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) # 对词嵌入向量进行维度变换 if hasattr(self, "embeddings_project"): hidden_states = self.embeddings_project(hidden_states) # 送入 12 个 Transformer 进行多头自主力计算 hidden_states = self.encoder( hidden_states, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) return hidden_states
ElectraDiscriminatorPredictions,将 ElectraModel 中的每个 Token 对应的 last_hidden_state 转换为一个 Logit,我们看下 ElectraDiscriminatorPredictions 代码的计算:
class ElectraDiscriminatorPredictions(nn.Module): """Prediction module for the discriminator, made up of two dense layers.""" def __init__(self, config): super().__init__() # 对输入进行线性变换 self.dense = nn.Linear(config.hidden_size, config.hidden_size) # 将输入映射为单个 logit 值 self.dense_prediction = nn.Linear(config.hidden_size, 1) self.config = config def forward(self, discriminator_hidden_states): hidden_states = self.dense(discriminator_hidden_states) # 默认使用 gelu 激活函数 hidden_states = get_activation(self.config.hidden_act)(hidden_states) logits = self.dense_prediction(hidden_states).squeeze(-1) return logits
3. ELECTRA 下游任务
虽然在 ELECTRA 模型中出现了生成器和判别器两部分,但是我们在进行下有任务时只需要使用生成器就可以了,这是因为生成器是用来生成对输入的表示,而判别器在训练或者微调阶段,对生成器表示的一种引导、监督,为的是能够获得更好的输入表示。
在 transformers 库中提供了如下下游任务的封装:
- ElectraForPreTraining 封装 RTD 预训练任务;
- ElectraForSequenceClassification 用于 Sequence 级别的分类问题;
- ElectraForTokenClassification 用于 Token 级别的分类问题;
- …
下面是用于 Sequence 级别分类任务的 Head 源代码,我们可以看到 ElectraClassificationHead 中对 ElectraModel 输出结果中的 [CLS] 对应的输出进行了分类,这个和 BERT 是一样的。
class ElectraClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) classifier_dropout = ( config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob ) self.dropout = nn.Dropout(classifier_dropout) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) def forward(self, features, **kwargs): x = features[:, 0, :] # take <s> token (equiv. to [CLS]) x = self.dropout(x) x = self.dense(x) x = get_activation("gelu")(x) # although BERT uses tanh here, it seems Electra authors used gelu here x = self.dropout(x) x = self.out_proj(x) return x
下面是用于 Token 级别分类任务的代码实现,我们可以看到 ElectraForTokenClassification 直接对 ElectraModel 模型的输出的 last_hidden_state 进行了分类。
class ElectraForTokenClassification(ElectraPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.electra = ElectraModel(config) classifier_dropout = ( config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init()