ELECTRA

我们知道 BERT 模型是应用 MLM 预训练任务得来的,而 ELECTRA(Efficiently Learning an Encoder that Classifies Token Replacements Accurately)模型则是使用 RTD(Replaced Token Detection)预训练任务得来的。ELECTRA 在使用鞥更少的参数和数据的情况下,也能够带来非常好的效果。

Paper 链接:https://openreview.net/pdf?id=r1xMH1BtvB

1. ELECTRA 计算过程

我截取了原 Paper 中的图来说明 ELECTRA 模型的计算过程,如下图所示:

BERT 的 MLM 预训练任务会对随机选择 15% 的 Token, 然后在这 15% 中的 80% 随机替换 [MASK],10% 保持原 Token, 10% 随机替换成其他 Token。

ELECTRA 则有些不同,它会选择部分的 Token 替换为 [MASK],但是替换成什么会由 Genertor 模型进行预测,用预测得到的 Token 替换 [MASK],即:Generator 计算之后会得到一个新的序列,这个序列其实是 Generator 根据带 [MASK] 的输入计算得出的 last_hidden_state 最后一层的隐藏状态。

接下来,将 last_hidden_state 输入到 Discriminator 中,Discriminator 则判断输入中的每一个 Token 是否被修改过,所以对应每个输入 Token 的值都是一个 logit。例如我们输入:”[CLS] 我 是 谁 ? [SEP]”,则输出的结果为:

[ 0.1326,  0.0408, -0.0277,  0.0670, -0.0679, -0.0209]

我们接下来可以通过 BCEWithLogitsLoss 来判断每个 Token 的损失,并进行反向传播,学习参数了。在上面的过程中,Generator 的职责有点像生成对抗网络(GAN)中的生成器,用来产生一个序列。Discriminator 则对应了判别器,判断生成器是否对输入进行了修改。

我在想,通过 RTD 模型想要学到什么能力呢?

  1. 当判别器预测错误时,判别器认为生成器对词的表示不够好,导致判别器无法区分,编码器必须重新调整对词的表示;
  2. 当判别器预测正确时,判别器认为生成器对词的表示还不错,不需要剧烈的调整表示;

也就是说,ELECTRA 希望啥时候判别器能够更好的区分词的表示了,那么生成器的词向量也就变得更加好了。

2. Transformers 源码理解

transformers 库可以通过下面的命令进行安装:

pip install transformers

在 transformers 库中对 ELECTRA 进行了实现。我们以 ElectraForPreTraining 实现来理解下 ELECTRA 模型的计算过程:

class ElectraForPreTraining(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # 生成器
        self.electra = ElectraModel(config)
        # 判别器
        self.discriminator_predictions = ElectraDiscriminatorPredictions(config)
        # Initialize weights and apply final processing
        self.post_init()

从实现代码可以看到,ELECTRA 模型包含了 2 个部分,ElectraModel 部分再预训练时,则对包含 [MASK] 的输入进行编码,该部分的实现代码如下:

 class ElectraModel(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # 输入进行词嵌入
        self.embeddings = ElectraEmbeddings(config)

        # 如果词嵌入维度和隐藏层输出维度不同,则线性变换到隐藏层维度
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)

        # 默认包含了 12 个 Transformer 隐藏层
        self.encoder = ElectraEncoder(config)
        self.config = config
        # Initialize weights and apply final processing
        self.post_init()

接下来,看下 ElectraModel 类前向计算函数:

   def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None
        # ... 省略部分参数
    ) 
        # ... 省略部分代码

        # 先进行词嵌入
        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        # 对词嵌入向量进行维度变换
        if hasattr(self, "embeddings_project"):
            hidden_states = self.embeddings_project(hidden_states)

        # 送入 12 个 Transformer 进行多头自主力计算
        hidden_states = self.encoder(
            hidden_states,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return hidden_states

ElectraDiscriminatorPredictions,将 ElectraModel 中的每个 Token 对应的 last_hidden_state 转换为一个 Logit,我们看下 ElectraDiscriminatorPredictions 代码的计算:

class ElectraDiscriminatorPredictions(nn.Module):
    """Prediction module for the discriminator, made up of two dense layers."""

    def __init__(self, config):
        super().__init__()

        # 对输入进行线性变换
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 将输入映射为单个 logit 值
        self.dense_prediction = nn.Linear(config.hidden_size, 1)
        self.config = config

    def forward(self, discriminator_hidden_states):
        hidden_states = self.dense(discriminator_hidden_states)
        # 默认使用 gelu 激活函数
        hidden_states = get_activation(self.config.hidden_act)(hidden_states)
        logits = self.dense_prediction(hidden_states).squeeze(-1)

        return logits

3. ELECTRA 下游任务

虽然在 ELECTRA 模型中出现了生成器和判别器两部分,但是我们在进行下有任务时只需要使用生成器就可以了,这是因为生成器是用来生成对输入的表示,而判别器在训练或者微调阶段,对生成器表示的一种引导、监督,为的是能够获得更好的输入表示。

在 transformers 库中提供了如下下游任务的封装:

  1. ElectraForPreTraining 封装 RTD 预训练任务;
  2. ElectraForSequenceClassification 用于 Sequence 级别的分类问题;
  3. ElectraForTokenClassification 用于 Token 级别的分类问题;

下面是用于 Sequence 级别分类任务的 Head 源代码,我们可以看到 ElectraClassificationHead 中对 ElectraModel 输出结果中的 [CLS] 对应的输出进行了分类,这个和 BERT 是一样的。

class ElectraClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = get_activation("gelu")(x)  # although BERT uses tanh here, it seems Electra authors used gelu here
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

下面是用于 Token 级别分类任务的代码实现,我们可以看到 ElectraForTokenClassification 直接对 ElectraModel 模型的输出的 last_hidden_state 进行了分类。

class ElectraForTokenClassification(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.electra = ElectraModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()
未经允许不得转载:一亩三分地 » ELECTRA