SimCSE: Training Sentence Embeddings with Contrastive Learning

SimCSE proposes an unsupervised method for training sentence embeddings that treats Dropout as a data augmentation technique: the same sentence is encoded twice, and the two dropout-perturbed representations form a positive pair.

Paper: https://aclanthology.org/2021.emnlp-main.552.pdf
Github: https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/text_matching/simcse

SimCSE is implemented in PaddleNLP; I adapted that implementation into a PyTorch version. The forward function works as follows:

  1. query and title are two copies of the same batch of sentences
  2. query and title are each fed through the BERT model to obtain one sentence representation per example, i.e. the [CLS] vector
  3. The cosine similarity between every sentence vector in query and every sentence vector in title is computed, giving a [batch_size, batch_size] similarity matrix
  4. Each row holds the similarities between one query sentence and all title sentences. We then build a label sequence: with batch_size=3 the labels are [0, 1, 2], meaning the first query sentence should be similar to the first title sentence and dissimilar to the others, and likewise for the remaining rows; the loss is computed against these labels
  5. Return the loss (see the minimal sketch right after this list)
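
To make steps 3-5 concrete, here is a minimal standalone sketch of the in-batch loss. The function name is illustrative, and the two embedding matrices are assumed to be L2-normalized already; the full model class shown below does the same thing inside forward.

import torch
import torch.nn.functional as F


def in_batch_simcse_loss(query_emb, title_emb, scale=20):
    # query_emb, title_emb: [batch_size, dim], already L2-normalized
    # pairwise cosine similarities -> [batch_size, batch_size]
    cosine_sim = query_emb @ title_emb.t()
    # row i should match column i, so the labels are simply 0..batch_size-1
    labels = torch.arange(query_emb.size(0), device=query_emb.device)
    return F.cross_entropy(cosine_sim * scale, labels)


# example with batch_size=3 and 256-dim embeddings
q = F.normalize(torch.randn(3, 256), dim=-1)
t = F.normalize(torch.randn(3, 256), dim=-1)
print(in_batch_simcse_loss(q, t))   # a scalar loss, e.g. tensor(1.2xxx)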

1. Training Approach 1

At first I trained with the original SimCSE recipe: each training sample is fed through the BERT model twice, and in-batch negatives are used. I tried the pretrained albert_chinese_tiny, bert-base-chinese, and chinese-bert-wwm-ext checkpoints, but the results were not very good. "Not good" here is a conclusion reached by manually inspecting the matching results. The training code is below.

Model implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimCSE(nn.Module):

    def __init__(self, dropout=0.1, margin=0.0, scale=20, ebdsize=256):
        """
        :param dropout: 丢弃率
        :param margin: 当两个样本的句子小于 margin 则认为是相似的
        :param scale: 放大相似度的值,帮助模型更好地区分不同的相似度级别,有利于模型收敛
        :param ebdsize: 输出的向量维度
        """

        super(SimCSE, self).__init__()
        # initialize the encoder
        # from transformers import BertModel
        # self.encoder = BertModel.from_pretrained('pretrained/bert-base-chinese')
        # self.encoder = BertModel.from_pretrained('pretrained/chinese-bert-wwm-ext')
        from transformers import AlbertModel
        self.encoder = AlbertModel.from_pretrained('pretrained/albert_chinese_tiny')
        # dropout applied to the projected [CLS] embedding
        self.dropout = nn.Dropout(dropout)
        # linear projection controlling the output embedding dimension
        self.ebdsize = ebdsize
        self.outputs = nn.Linear(self.encoder.config.hidden_size, ebdsize)

        self.margin = margin
        self.scale = scale

    def get_encoder_embedding(self, input_ids, attention_mask=None, with_pooler=False):

        # encode the batch; below we take either the pooled output or the raw [CLS] vector
        sequence_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        if with_pooler:
            cls_embedding = sequence_output.pooler_output
        else:
            cls_embedding = sequence_output.last_hidden_state[:, 0, :]

        cls_embedding = self.outputs(cls_embedding)
        cls_embedding = self.dropout(cls_embedding)
        # L2-normalize the vectors (divide each by its 2-norm) so that cosine similarity later reduces to a dot product
        cls_embedding = F.normalize(cls_embedding, p=2, dim=-1)

        return cls_embedding


    def save_pretrained(self, model_save_path):

        self.encoder.save_pretrained(model_save_path)
        torch.save(self.outputs.state_dict(), model_save_path + '/dim_reduce.pth')
        model_param = {'dropout': self.dropout, 'ebdsize': self.ebdsize, 'margin': self.margin, 'scale': self.scale}
        torch.save(model_param, model_save_path + '/model_param.pth')


    def from_pretrained(self, model_save_path):

        model_param = torch.load(model_save_path + '/model_param.pth')
        self.scale = model_param['scale']
        self.margin = model_param['margin']
        self.dropout = model_param['dropout']
        self.ebdsize = model_param['ebdsize']
        # from_pretrained returns a new model object, so assign the result back
        self.encoder = self.encoder.from_pretrained(model_save_path)
        self.outputs.load_state_dict(torch.load(model_save_path + '/dim_reduce.pth', map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))

        return self


    def forward(self, query_input_ids, title_input_ids, query_attention_mask=None, title_attention_mask=None):

        query_cls_embedding = self.get_encoder_embedding(query_input_ids, query_attention_mask)
        title_cls_embedding = self.get_encoder_embedding(title_input_ids, title_attention_mask)

        # device the inputs live on
        device = query_cls_embedding.device

        # query_cls_embedding and title_cls_embedding are already L2-normalized,
        # so a matrix product gives the pairwise cosine similarities
        # torch.Size([3, 256])
        # torch.Size([256, 3])
        cosine_sim = torch.matmul(query_cls_embedding, title_cls_embedding.transpose(1, 0))

        # tensor([0., 0., 0.])
        margin_diag = torch.full(size=[query_cls_embedding.size(0)], fill_value=self.margin).to(device)

        # tensor([[0.8873, 0.8727, 0.8366],
        #         [0.8876, 0.8834, 0.9100],
        #         [0.9068, 0.9079, 0.8703]], grad_fn=<SubBackward0>)
        cosine_sim = cosine_sim - torch.diag(margin_diag)

        # scale up the similarities, which helps the model converge
        # tensor([[17.7461, 17.4537, 16.7329],
        #         [17.7512, 17.6680, 18.2001],
        #         [18.1369, 18.1573, 17.4062]], grad_fn=<MulBackward0>)
        cosine_sim *= self.scale

        # build labels: the i-th query matches the i-th title
        # tensor([0, 1, 2])
        labels = torch.arange(0, query_cls_embedding.size(0)).to(device)

        # cross-entropy loss over the rows of the similarity matrix
        # tensor(1.2422, grad_fn=<NllLossBackward0>)
        loss = F.cross_entropy(input=cosine_sim, target=labels)


        return loss
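
For reference, a quick smoke test of the class above. This is a hypothetical usage sketch, assuming the pretrained/albert_chinese_tiny checkpoint and its tokenizer are available locally:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('pretrained/albert_chinese_tiny')
model = SimCSE()

sentences = ['最近总是头疼是怎么回事', '胃部不舒服吃什么药']
batch = tokenizer(sentences, padding='longest', return_token_type_ids=False, return_tensors='pt')

# encode the same batch twice; dropout makes the two sets of [CLS] vectors differ
loss = model(query_input_ids=batch['input_ids'],
             title_input_ids=batch['input_ids'],
             query_attention_mask=batch['attention_mask'],
             title_attention_mask=batch['attention_mask'])
print(loss)   # scalar training loss for this mini-batch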

Training uses the cMedQA dataset, available at https://www.luge.ai/#/luge/dataDetail?id=70; only the question.csv and answer.csv files from v2.0 are needed. The dataset contains about 120,000 questions, and I only use the first 10,000 of them to train the sentence embeddings. The full training script is shown below.
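
The DataSimCSE class in the script loads the whole content column, so limiting the data to the first 10,000 questions is assumed to happen when reading the file. A minimal sketch of that step, assuming question.csv has a content column:

import pandas as pd

# keep only the first 10,000 questions (assumption: the subset is taken at load time)
questions = pd.read_csv('data/question.csv')['content'][:10000]
print(len(questions))   # 10000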

import torch
from simcse import SimCSE
import torch.optim as optim
import torch.nn as nn
import pandas as pd
from torch.utils.data import DataLoader
from functools import partial
from transformers import BertTokenizer
from tqdm import tqdm
import torch.nn.functional as F
import numpy as np
import random


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Training device:', device)


class DataSimCSE:

    def __init__(self):
        self._data = self._get_train()

    def _get_train(self):
        train = pd.read_csv('data/question.csv')['content']
        return train

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index):
        return {'one': self._data[index], 'two': self._data[index]}


def collate_function(tokenizer, batch_data):

    ones, twos = [], []
    for data in batch_data:
        ones.append(data['one'])
        twos.append(data['two'])

    ones = tokenizer(ones, return_token_type_ids=False, padding='longest', return_tensors='pt')
    twos = tokenizer(twos, return_token_type_ids=False, padding='longest', return_tensors='pt')
    ones = {key: value.to(device) for key, value in ones.items()}
    twos = {key: value.to(device) for key, value in twos.items()}

    model_inputs = {}
    model_inputs['query_input_ids'] = ones['input_ids']
    model_inputs['title_input_ids'] = twos['input_ids']
    model_inputs['query_attention_mask'] = ones['attention_mask']
    model_inputs['title_attention_mask'] = twos['attention_mask']

    return model_inputs


def train_simcse():

    estimator = SimCSE().to(device)
    optimizer = optim.Adam(estimator.parameters(), lr=1e-5)
    tokenizer = BertTokenizer.from_pretrained('pretrained/bert-base-chinese')
    dataloader = DataLoader(DataSimCSE(), shuffle=True, batch_size=8, collate_fn=lambda batch_data: collate_function(tokenizer, batch_data))
    epoch_num = 40

    for epoch in range(epoch_num):

        progress = tqdm(range(len(dataloader)))
        epoch_loss = 0.0
        for model_inputs in dataloader:
            loss = estimator(**model_inputs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            progress.set_description('epoch %2d %8.4f' % (epoch + 1, epoch_loss))
            progress.update()
        progress.close()

        model_save_path = 'model/epoch_%d_simcse_loss_%.4f' % (epoch + 1, epoch_loss)
        estimator.save_pretrained(model_save_path)
        tokenizer.save_pretrained(model_save_path)


if __name__ == '__main__':
    train_simcse()
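
After training, the saved model can be loaded to embed sentences and compare them directly with a dot product, since the vectors are already normalized. A hypothetical inference sketch that reuses the imports, device, and SimCSE class from the script above; model_dir stands for whichever epoch directory save_pretrained produced:

@torch.no_grad()
def encode(estimator, tokenizer, sentences):
    inputs = tokenizer(sentences, padding='longest', return_token_type_ids=False, return_tensors='pt')
    inputs = {key: value.to(device) for key, value in inputs.items()}
    return estimator.get_encoder_embedding(**inputs)


model_dir = 'model/...'  # placeholder: the checkpoint directory written during training
estimator = SimCSE().from_pretrained(model_dir).eval().to(device)
tokenizer = BertTokenizer.from_pretrained(model_dir)

embeddings = encode(estimator, tokenizer, ['最近总是头疼是怎么回事', '经常头痛该怎么办'])
# embeddings are L2-normalized, so the dot product is the cosine similarity
print(embeddings[0] @ embeddings[1])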

2. Training Approach 2

While reading more about SimCSE, I came across a follow-up improvement. The reasoning: the two dropout views of a sentence always have exactly the same length, so the model may learn to rely on sentence length rather than meaning; word repetition is therefore used to perturb the sentence lengths as an extra augmentation. The implementation simply duplicates randomly chosen positions in input_ids. The augmentation code:

def word_repetition(input_ids, dup_rate=0.32):

    # number of sequences in the batch
    batch_size = len(input_ids)
    repetitied_input_ids = []

    for batch_id in range(batch_size):
        cur_input_id = input_ids[batch_id]
        # actual (non-padding) length of the sentence
        actual_len = np.count_nonzero(cur_input_id)
        dup_word_index = []
        # only augment sentences longer than 5 tokens
        if actual_len > 5:
            dup_len = random.randint(a=0, b=max(2, int(dup_rate * actual_len)))
            # Skip cls and sep position
            dup_word_index = random.sample(list(range(1, actual_len - 1)), k=dup_len)

        r_input_id = []
        for idx, word_id in enumerate(cur_input_id):
            # duplicate the token at the sampled positions
            if idx in dup_word_index:
                r_input_id.append(word_id)
            r_input_id.append(word_id)

        repetitied_input_ids.append(r_input_id)

    # pad every sequence to the batch max length and rebuild the attention mask
    repetitied_input_ids_mask = []
    batch_maxlen = max([len(ids) for ids in repetitied_input_ids])
    for batch_id in range(batch_size):
        after_dup_len = len(repetitied_input_ids[batch_id])
        pad_len = batch_maxlen - after_dup_len
        repetitied_input_ids[batch_id] += [0] * pad_len

        mask = np.ones((len(repetitied_input_ids[batch_id]), ), dtype=np.int32)
        mask[np.array(repetitied_input_ids[batch_id]) == 0] = 0

        repetitied_input_ids_mask.append(mask.tolist())


    repetitied_input_ids = torch.tensor(repetitied_input_ids, device=device)
    repetitied_input_ids_mask = torch.tensor(repetitied_input_ids_mask, device=device)

    return repetitied_input_ids, repetitied_input_ids_mask
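
A quick illustration of what the augmentation does to one tokenized batch (run in the same script, so device and BertTokenizer are already available); the duplicated positions are random, so the exact lengths vary from call to call:

tokenizer = BertTokenizer.from_pretrained('pretrained/bert-base-chinese')
batch = tokenizer(['最近总是头疼是怎么回事', '胃部不舒服吃什么药'], return_token_type_ids=False)['input_ids']

input_ids, attention_mask = word_repetition(batch)
# both sequences grow by the number of duplicated tokens and are padded to the
# same length; the attention mask is rebuilt to match the new padding
print(input_ids.shape, attention_mask.shape)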

The training code becomes:

class DataSimCSE:

    def __init__(self):
        self._data = self._get_train()

    def _get_train(self):
        train = pd.read_csv('data/question.csv')['content']
        return train

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index):
        return {'one': self._data[index], 'two': self._data[index]}


# collate function that applies word repetition augmentation to the inputs
def collate_function_aug(tokenizer, batch_data):

    ones, twos = [], []
    for data in batch_data:
        ones.append(data['one'])
        twos.append(data['two'])

    ones = tokenizer(ones, return_token_type_ids=False)['input_ids']
    twos = tokenizer(twos, return_token_type_ids=False)['input_ids']

    model_inputs = {}
    input_ids, attention_mask = word_repetition(ones)
    model_inputs['query_input_ids'] = input_ids
    model_inputs['query_attention_mask'] = attention_mask

    input_ids, attention_mask = word_repetition(twos)
    model_inputs['title_input_ids'] = input_ids
    model_inputs['title_attention_mask'] = attention_mask

    return model_inputs


def train_simcse():

    estimator = SimCSE().to(device)
    optimizer = optim.Adam(estimator.parameters(), lr=1e-5)
    tokenizer = BertTokenizer.from_pretrained('pretrained/bert-base-chinese')
    dataloader = DataLoader(DataSimCSE(), shuffle=True, batch_size=8, collate_fn=lambda batch_data: collate_function_aug(tokenizer, batch_data))
    epoch_num = 20

    for epoch in range(epoch_num):

        progress = tqdm(range(len(dataloader)))
        epoch_loss = 0.0
        for model_inputs in dataloader:
            loss = estimator(**model_inputs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            progress.set_description('epoch %2d %8.4f' % (epoch + 1, epoch_loss))
            progress.update()
        progress.close()

        if epoch > 5:
            model_save_path = 'model_aug/epoch_%d_simcse_loss_%.4f' % (epoch + 1, epoch_loss)
            estimator.save_pretrained(model_save_path)
            tokenizer.save_pretrained(model_save_path)


if __name__ == '__main__':
    # test()
    train_simcse()

3. Training Approach 3

This augmentation did not noticeably help on my problem, so I also changed the training procedure itself. What I want the model to learn is a clear separation between different questions. Previously each batch was drawn at random from the dataset, so many of its negatives were already easy to tell apart from the positive; it is better to focus on the samples that are not yet clearly separable and use those as negatives.

So the in-batch negatives are built differently: at the end of each epoch, the current model generates embeddings for all sentences, which are stored in a faiss index. For each question, the 64 most similar samples are retrieved from the index; after removing the question itself, the remaining 63 most similar samples serve as its negatives for the next epoch.
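
The retrieval itself is a standard faiss building block: because the embeddings are L2-normalized, an inner-product index (IndexFlatIP) ranks sentences by cosine similarity. A minimal sketch with random vectors standing in for the real question embeddings:

import faiss
import numpy as np

# 10,000 normalized 256-dim vectors in place of the question embeddings
vectors = np.random.randn(10000, 256).astype('float32')
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)
ids = np.arange(10000, dtype='int64')

index = faiss.IndexIDMap(faiss.IndexFlatIP(256))
index.add_with_ids(vectors, ids)

# for every vector, fetch its 64 nearest neighbours (the first hit is usually the vector itself)
scores, neighbour_ids = index.search(vectors, 64)
print(neighbour_ids.shape)   # (10000, 64)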

Training here still uses the length (word repetition) augmentation; this approach builds on Approach 2.

The function that rebuilds the positive/negative groups from the previous epoch's model is as follows:

import faiss

@torch.no_grad()
def generate_negative_samples(model_path=None, epoch=0, batch_size=5, start_init=False):

    if start_init:
        estimator = SimCSE().eval().to(device)
        tokenizer = BertTokenizer.from_pretrained('pretrained/albert_chinese_tiny')
    else:
        estimator = SimCSE().from_pretrained(model_path).eval().to(device)
        tokenizer = BertTokenizer.from_pretrained(model_path)

    # select_questions() (helper not shown) returns (question_id, question_text) pairs
    questions = select_questions()

    # encode a batch of (id, text) pairs
    def collate_function(batch_data):
        question_index, question_input = [], []
        for index, input in batch_data:
            question_index.append(index)
            question_input.append(input)
        question_input = tokenizer(question_input, padding='longest', return_token_type_ids=False, return_tensors='pt')
        question_input = {key: value.to(device) for key, value in question_input.items()}
        return question_index, question_input

    dataloader = DataLoader(questions, batch_size=128, collate_fn=collate_function)
    progress = tqdm(range(len(dataloader)), desc='Generating embeddings for epoch %d' % epoch)

    qid_to_ebd = {}
    qids, ebds = [], []
    for bqid, inputs in dataloader:
        bebd = estimator.get_encoder_embedding(**inputs)
        qids.extend(bqid)
        ebds.append(bebd)
        for qid, ebd in zip(bqid, bebd):
            qid_to_ebd[qid] = ebd
        progress.update()
    progress.set_description('Finished generating embeddings for epoch %d' % epoch)
    progress.close()

    # build the vector index; ebds has shape [10000, 256]
    ebds = torch.concat(ebds, dim=0).cpu().numpy()
    database = faiss.IndexIDMap(faiss.IndexFlatIP(256))
    database.add_with_ids(ebds, np.array(qids, dtype=np.int64))

    # for every question, retrieve its nearest neighbours to use as hard negatives
    _, search_ids = database.search(ebds, batch_size)
    questions = dict(questions)
    candidate_questions = []
    for qid, sqids in zip(qids, search_ids.tolist()):
        if qid in sqids:
            sqids.remove(qid)
        candidate_questions.append([questions[id] for id in [qid] + sqids[:batch_size-1]])

    return candidate_questions

The training function:

def train_simcse():

    estimator = SimCSE().to(device)
    optimizer = optim.Adam(estimator.parameters(), lr=1e-5)
    tokenizer = BertTokenizer.from_pretrained('pretrained/bert-base-chinese')
    epoch_num = 20
    batch_size = 64
    dataset = select_questions()
    # build the initial negative groups with the untrained model
    current_data = generate_negative_samples(batch_size=batch_size, start_init=True)
    current_indexes = list(range(len(current_data)))

    def collate_function(batch_data):

        ones = tokenizer(batch_data, return_token_type_ids=False)['input_ids']
        twos = tokenizer(batch_data, return_token_type_ids=False)['input_ids']

        model_inputs = {}
        input_ids, attention_mask = word_repetition(ones)
        model_inputs['query_input_ids'] = input_ids
        model_inputs['query_attention_mask'] = attention_mask

        input_ids, attention_mask = word_repetition(twos)
        model_inputs['title_input_ids'] = input_ids
        model_inputs['title_attention_mask'] = attention_mask

        return model_inputs

    for epoch in range(epoch_num):

        epoch_loss = 0.0
        random.shuffle(current_indexes)
        progress = tqdm(range(len(current_indexes)))
        for index, current_index in enumerate(current_indexes):

            model_inputs = collate_function(current_data[current_index])
            loss = estimator(**model_inputs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += (loss.item() * len(model_inputs['query_input_ids']))
            progress.set_description('epoch %2d %8.4f' % (epoch + 1, epoch_loss))
            progress.update()
        progress.close()

        model_save_path = 'finish/semantics/simcse/%2d_semantic_simcse_loss_%.4f' % (epoch + 1, epoch_loss)
        estimator.save_pretrained(model_save_path)
        tokenizer.save_pretrained(model_save_path)
        current_data = generate_negative_samples(model_path=model_save_path, epoch=epoch+1, batch_size=batch_size, start_init=False)