高效的预训练 ELECTRA 语言模型

ELECTRA(Efficiently Learning an Encoder that Classifies Token Replacements Accurately)是一种用于自然语言处理(NLP)的预训练模型,旨在提升模型训练效率和性能。ELECTRA 模型由 Google Research 提出,相比传统的 BERT 模型,它采用了一种不同的预训练策略。

Paper:https://openreview.net/pdf?id=r1xMH1BtvB

1. 训练目标

ELECTRA的核心思想是通过 生成替换与判别任务(Replaced Token Detection,RTD) 来进行训练。具体来说,ELECTRA 的训练过程由以下两部分组成:

  • 生成器(Generator):通过最大似然估计训练,目的是生成与上下文相关的替代标记
  • 判别器(Discriminator):通过二元分类任务训练,学习区分真实标记和生成器生成的替代标记

在上图中, 将文本 the chef cooked the meal 中的某些词随机替换为 [mask],然后输入到 Generator ,并由 Generator 将其中的 [MASK] 预测(替换)为其他的词。例如,将第一个 [mask] 替换为原词 the,而第二个 [mask] 则替换为另一个词 ate,这样就得到了一条训练样本。Discriminator 的任务是对生成器生成的文本进行判断,预测 token 是是 original 还是 replaced

ELECTRA 的损失函数由生成器损失和判判别器损失组成。生成器的损失函数是基于Masked Language Model 的损失函数。具体来说,生成器的任务是预测被掩盖的 token,因此其损失函数为 MLM 损失函数:

其中,\( M \) 是掩蔽的词位置,\( x_{mask} \)​ 是掩蔽后的词,\( \theta \) 是生成器的参数,\( P_{\theta}(x_{i}|x_{mask}) \) 是生成器预测的概率。

生成器生成的 假词 将会被标记为负样本,原始词 则作为正样本。判别器的损失就是预测每个token 被替换过的概率。

其中,\( D \) 是输入数据,\( \hat{x} \) 是被替换的 token,\( \phi \) 是判别器的参数,\( P_{\phi}(real|x) \) 和 \(P_{\phi}(fake|\hat{x}) \) 分别是判别器对 真token假token 的预测概率。

ELECTRA 的总损失函数是生成器损失和判别器损失的加权和:

其中,\( \lambda \) 是一个超参数,控制生成器损失和判别器损失之间的相对重要性。

通过上面损失的计算过程,我们发现在 Electra 中,并没有采用对抗性训练(Adversarial Training),虽然 Electra 的结构看起来类似于 GAN(生成对抗网络),但它实际上不是一个真正的对抗式训练模型。

在 ​GAN 中,生成器和判别器是 ​对抗性 的,生成器试图生成更逼真的样本,而判别器试图区分真实样本和生成样本。生成器的训练目标是 ​欺骗判别器,而判别器的训练目标是 ​不被欺骗

在 ​Electra 中,生成器的训练目标是 ​预测被 mask 的 token,而不是欺骗判别器。生成器的损失函数是基于其预测的 token 与真实 token 的交叉熵损失,与判别器的任务无关。判别器的训练目标是 ​区分原始 token 和替代 token,而不是与生成器进行对抗。

2. 下游任务

Electra 采用了一个较小的生成器(Generator)和一个更强大的判别器(Discriminator)。生成器的能力受限,它不会完全准确地预测真实 token,而是生成接近但带有一定偏差的 token,从而构造用于训练的伪数据。判别器的任务是学习如何区分真实 token 和生成器生成的伪 token,这一过程中,判别器会学习到更丰富的 token 表征能力。

换句话说,生成器的主要作用是辅助训练,为判别器提供学习样本,而判别器则通过学习 token 之间的细微差异来获取更好的表示能力。在下游任务中,我们通常基于判别器学习到的文本表征进行微调,以适应不同的任务需求。

英文语料预训练:
https://huggingface.co/google/electra-small-discriminator
https://huggingface.co/google/electra-base-discriminator
中文语料预训练:
https://huggingface.co/hfl/chinese-electra-small-discriminator
https://huggingface.co/hfl/chinese-electra-base-discriminator
https://huggingface.co/hfl/chinese-electra-180g-small-discriminator
https://huggingface.co/hfl/chinese-electra-180g-base-discriminator
https://huggingface.co/hfl/chinese-electra-180g-small-ex-discriminator

接下来,使用 chinese-electra-small-discriminator 来微调一个简单的文本分类任务。

from transformers import ElectraForSequenceClassification
from transformers import ElectraTokenizer
import torch
from transformers import Trainer
from transformers import TrainingArguments
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score



def demo():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    modelpath = 'chinese-electra-small-discriminator'
    tokenizer = ElectraTokenizer.from_pretrained(modelpath)
    estimator = ElectraForSequenceClassification.from_pretrained(modelpath, num_labels=2).to(device)
    samples = pd.read_csv('comments.csv').dropna()
    x_train, x_valid = train_test_split(samples, test_size=0.2, shuffle=True, stratify=samples['label'])
    x_train, x_valid = x_train.to_numpy(), x_valid.to_numpy()

    def get_dataset(data):
        all_data = []
        for label, input in data:
            input = tokenizer.encode_plus(input,
                                          truncation=True,
                                          max_length=512,
                                          return_tensors=None,
                                          return_token_type_ids=False,
                                          return_attention_mask=False)
            input['labels'] = torch.tensor(label)
            all_data.append(input)
        return all_data

    x_train = get_dataset(x_train)
    x_valid = get_dataset(x_valid)

    training_args = TrainingArguments(
        output_dir='results',
        eval_strategy='epoch',
        save_strategy='epoch',
        logging_strategy='no',
        learning_rate=5e-5,
        optim='adamw_torch',
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=10,
        weight_decay=0.01,
        disable_tqdm=True,
    )

    def do_metric(eval_pred):
        # # 获取预测的 logits 和真实标签
        logits, labels = eval_pred
        predictions = logits.argmax(axis=-1)
        acc = accuracy_score(labels, predictions)
        return {'acc': acc}

    trainer = Trainer(model=estimator,
                      args=training_args,
                      train_dataset=x_train,
                      eval_dataset=x_valid,
                      compute_metrics=do_metric,
                      processing_class=tokenizer)
    trainer.train()

    tokenizer.save_pretrained('saved_model')
    estimator.save_pretrained('saved_model')


if __name__ == '__main__':
    demo()
未经允许不得转载:一亩三分地 » 高效的预训练 ELECTRA 语言模型