Using a 4-Hidden-Layer BERT Model

We often use the pretrained model bert-base-chinese as the base model for downstream text-understanding NLP tasks. The original bert-base-chinese has 12 hidden (Transformer encoder) layers and a large number of parameters; in this example we use a 4-layer BERT model to solve a downstream classification task.

Because bert-base-chinese provides pretrained parameters for 12 layers while our model has only 4 hidden layers, the pretrained checkpoint cannot be loaded directly. Instead, we manually extract the parameters of layers 5, 7, 9 and 11 (out of the 12 layers, indexed 0-11) and use them as the hidden-layer parameters of the 4-layer model. Of course, other choices work just as well, for example layers 1, 2, 3, 4 or layers 8, 9, 10, 11.
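
To see exactly what gets remapped, here is a quick sketch (assuming the checkpoint has been downloaded locally to bert-base-chinese/pytorch_model.bin, as in the training code below) that lists how the 12 encoder layers are named inside the checkpoint:

    import torch

    params = torch.load('bert-base-chinese/pytorch_model.bin')
    # Collect the layer indices that appear in names such as bert.encoder.layer.5.attention.self.query.weight
    layers = sorted({name.split('.')[3] for name in params if name.startswith('bert.encoder.layer.')}, key=int)
    print(layers)  # ['0', '1', ..., '11']
    # Layers 5, 7, 9 and 11 will be renamed to layers 0, 1, 2 and 3 of the 4-layer model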

Likewise, the 4-hidden-layer BERT model here is only an example: a 2-layer or 6-layer model works too, and the number of attention heads can stay at the default 12 or be changed to 8, and so on.
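
For example, a 6-layer model with 8 attention heads only needs a different config. The sketch below uses illustrative values; the hidden size stays at the bert-base-chinese default of 768, which must remain divisible by the number of attention heads (and because the attention projections are 768x768 regardless of the head count, the extracted pretrained weights still load):

    from transformers import BertConfig

    config = BertConfig()
    config.num_hidden_layers = 6      # keep, e.g., 6 of the 12 pretrained encoder layers
    config.num_attention_heads = 8    # 768 % 8 == 0, so the 768x768 attention projections still fit
    config.vocab_size = 21128         # vocabulary size of bert-base-chinese
    config.num_labels = 2             # number of downstream classes (illustrative)

The complete training and evaluation script follows.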

import torch
from transformers import BertForSequenceClassification
from transformers import BertTokenizer
from ignite.engine import Engine
from ignite.engine import Events
from datasets import load_from_disk
from datasets import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
from ignite.handlers import Checkpoint
from ignite.contrib.handlers import PiecewiseLinear
from ignite.contrib.handlers import ProgressBar
import pickle
from ignite.metrics import Accuracy
from ignite.metrics import ClassificationReport
from transformers import BertConfig
import sys
import time
import collections
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_step(engine, batch_data):
    # The evaluator runs at the end of every epoch and leaves the model in eval mode,
    # so switch it back to train mode before each training step.
    engine.model.train()
    sentences, labels = batch_data['content'], batch_data['label']
    params = {'return_tensors': 'pt', 'padding': 'longest'}
    inputs = engine.tokenizer(sentences, **params)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    output = engine.model(**inputs, labels=labels.to(device))
    engine.optimizer.zero_grad()
    output.loss.backward()
    engine.optimizer.step()

    return output.loss.item()


def on_train_epoch_started(engine):
    engine.epoch_loss = 0
    engine.progressbar = tqdm(range(engine.state.epoch_length), bar_format='{desc} |{bar}| [{elapsed}<{remaining}]')

def on_train_iteration_completed(engine):
    engine.epoch_loss += engine.state.output
    epoch_desc = f'Train [{engine.state.epoch}/{engine.state.max_epochs}] '
    iteration_desc = f'[{engine.state.iteration}/{engine.state.epoch_length}] '
    batch_loss_desc = f'loss {engine.state.output:.4f}'
    engine.progressbar.set_description(epoch_desc + iteration_desc + batch_loss_desc)
    engine.progressbar.update()

def on_train_epoch_completed(engine):
    engine.total_losses.append(engine.epoch_loss)

    epoch_desc = f'Train [{engine.state.epoch}/{engine.state.max_epochs}] '
    iteration_desc = f'[{engine.state.iteration}/{engine.state.epoch_length}] '
    epoch_loss_desc = f'loss {engine.epoch_loss:.4f}'
    engine.progressbar.set_description(epoch_desc + iteration_desc + epoch_loss_desc)
    engine.progressbar.close()

    engine.evaluator.run(engine.valid_data)

def on_train_completed(engine):
    pickle.dump(engine.total_losses, open('model/loss.pkl', 'wb'))
    print('Training finished!')


def on_evaluate_epoch_started(engine):
    engine.progressbar = tqdm(range(engine.state.epoch_length), bar_format='{desc} |{bar}| [{elapsed}<{remaining}]')

def on_evaluate_iteration_completed(engine):
    epoch_desc = f'Eval [{engine.trainer.state.epoch}/{engine.trainer.state.max_epochs}] '
    iteration_desc = f'[{engine.state.iteration}/{engine.state.epoch_length}] '
    # Per-iteration accuracy is not computed; show a placeholder until the epoch ends.
    accuracy_desc = 'Acc 0.0000'
    engine.progressbar.set_description(epoch_desc + iteration_desc + accuracy_desc)
    engine.progressbar.update()

def on_evaluate_epoch_completed(engine):
    epoch_desc = f'Eval [{engine.trainer.state.epoch}/{engine.trainer.state.max_epochs}] '
    iteration_desc = f'[{engine.state.iteration}/{engine.state.epoch_length}] '
    accuracy_desc = f'Acc {engine.state.metrics["accuracy"]:.4f}'
    engine.progressbar.set_description(epoch_desc + iteration_desc + accuracy_desc)
    engine.progressbar.close()


def train():

    trainer = Engine(train_step)
    trainer.total_losses = []
    trainer.epoch_loss = 0
    max_epochs = 30
    train_data = load_from_disk('data/senti-data')['train']
    num_labels = len(set(train_data['label']))
    train_data = DataLoader(train_data, batch_size=4, shuffle=True)
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    trainer.tokenizer = tokenizer

    # Initialize the 4-layer model with the parameters of layers 5, 7, 9 and 11 of bert-base-chinese
    params = torch.load('bert-base-chinese/pytorch_model.bin')
    keep_layers = [5, 7, 9, 11]
    state_dict = collections.OrderedDict()
    layer_prefix = 'bert.encoder.layer.'
    for name, value in params.items():
        if name.startswith(layer_prefix):
            # Keep only the selected encoder layers and renumber them 0-3
            for index, layer_index in enumerate(keep_layers):
                if name.startswith(layer_prefix + str(layer_index) + '.'):
                    state_dict[name.replace(layer_prefix + str(layer_index) + '.', layer_prefix + str(index) + '.')] = value
                    break
            continue
        # Non-encoder parameters (embeddings, pooler, ...) are copied unchanged
        state_dict[name] = value
    config = BertConfig()
    config.num_hidden_layers = 4
    config.vocab_size = tokenizer.vocab_size
    config.num_labels = num_labels
    model = BertForSequenceClassification.from_pretrained(None, state_dict=state_dict, config=config)
    model = model.to(device)
    model.train()

    trainer.model = model
    trainer.add_event_handler(Events.EPOCH_STARTED, on_train_epoch_started)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, on_train_iteration_completed)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, on_train_epoch_completed)
    trainer.add_event_handler(Events.COMPLETED, on_train_completed)
    optimizer = optim.AdamW(model.parameters(), lr=5e-5)
    trainer.optimizer = optimizer
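    # Linearly decay the learning rate from 5e-5 down to 0 over all training iterations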
    milestones_values = [(0, 5e-5), (max_epochs * len(train_data), 0)]
    scheduler = PiecewiseLinear(optimizer, param_name='lr', milestones_values=milestones_values)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
    evaluator = Engine(test_step)
    Accuracy(device=device).attach(evaluator, 'accuracy')
    trainer.valid_data = DataLoader(load_from_disk('data/senti-data')['valid'], batch_size=8, shuffle=True)
    evaluator.add_event_handler(Events.EPOCH_STARTED, on_evaluate_epoch_started)
    evaluator.add_event_handler(Events.ITERATION_COMPLETED, on_evaluate_iteration_completed)
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, on_evaluate_epoch_completed)
    trainer.evaluator = evaluator
    evaluator.model = model
    evaluator.tokenizer = tokenizer
    evaluator.trainer = trainer
    # evaluator.run(trainer.valid_data)
    # Keep the 2 checkpoints with the highest validation accuracy under model/base4
    param = {'to_save': {'model': model},
             'save_handler': 'model/base4',
             'n_saved': 2,
             'score_function': lambda engine: engine.state.metrics['accuracy']}
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, Checkpoint(**param))
    trainer.run(train_data, max_epochs=max_epochs)


def test_step(engine, batch_data):
    engine.model.eval()
    with torch.no_grad():
        sentences, y_true = batch_data['content'], batch_data['label']
        params = {'return_tensors': 'pt', 'padding': 'longest'}
        inputs = engine.tokenizer(sentences, **params)
        inputs = {key: value.to(device) for key, value in inputs.items()}
        output = engine.model(**inputs)

    return output.logits, y_true.to(device)


def evaluate():

    evaluator = Engine(test_step)
    test_data = load_from_disk('data/senti-data')['test']
    num_labels = len(set(test_data['label']))
    test_data = DataLoader(test_data, batch_size=8, shuffle=True)
    config = BertConfig.from_pretrained('bert-base-chinese')
    config.num_labels = num_labels
    config.num_hidden_layers = 4
    model = BertForSequenceClassification(config=config)
    # model.load_state_dict(torch.load('model/evaluate/model_-31.9246.pt'))
    model.load_state_dict(torch.load('model/base4/model_0.5481.pt'))

    model = model.to(device)
    evaluator.model = model
    # Build the tokenizer
    evaluator.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    # ClassificationReport(output_dict=True, device=device).attach(evaluator, 'class_report')
    Accuracy(device=device).attach(evaluator, 'accuracy')
    ProgressBar().attach(evaluator)
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: print(engine.state.metrics['accuracy']))

    evaluator.run(test_data)


if __name__ == '__main__':
    train()
    # evaluate()

If you want to train a 4-layer BERT model from scratch on your own corpus for the downstream task (the bert-base-chinese tokenizer and vocabulary are still reused), you only need to replace the code below:

    # Initialize the 4-layer model with the parameters of layers 5, 7, 9 and 11 of bert-base-chinese
    params = torch.load('bert-base-chinese/pytorch_model.bin')
    keep_layers = [5, 7, 9, 11]
    state_dict = collections.OrderedDict()
    layer_prefix = 'bert.encoder.layer.'
    for name, value in params.items():
        if name.startswith(layer_prefix):
            # Keep only the selected encoder layers and renumber them 0-3
            for index, layer_index in enumerate(keep_layers):
                if name.startswith(layer_prefix + str(layer_index) + '.'):
                    state_dict[name.replace(layer_prefix + str(layer_index) + '.', layer_prefix + str(index) + '.')] = value
                    break
            continue
        # Non-encoder parameters (embeddings, pooler, ...) are copied unchanged
        state_dict[name] = value
    config = BertConfig()
    config.num_hidden_layers = 4
    config.vocab_size = tokenizer.vocab_size
    config.num_labels = num_labels
    model = BertForSequenceClassification.from_pretrained(None, state_dict=state_dict, config=config)
    model = model.to(device)
    model.train()

with:

    config = BertConfig()
    config.num_hidden_layers = 4
    config.vocab_size = tokenizer.vocab_size
    config.num_labels = num_labels
    model = BertForSequenceClassification(config=config)
    model = model.to(device)

On a dataset of 40,000 product reviews, the 4-layer model fine-tuned from bert-base-chinese performs better than the 4-layer model trained from scratch and also converges faster: the former reaches an accuracy of 0.9149, while the latter only reaches about 0.88.
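
To classify a new sentence with the fine-tuned 4-layer model, a minimal inference sketch is shown below. The checkpoint file name is taken from the evaluation code above and will differ for your own run, and num_labels = 2 is assumed here for a binary sentiment task:

    import torch
    from transformers import BertTokenizer, BertForSequenceClassification, BertConfig

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    config = BertConfig.from_pretrained('bert-base-chinese')
    config.num_hidden_layers = 4
    config.num_labels = 2  # assumed: binary sentiment labels; use your dataset's label count
    model = BertForSequenceClassification(config=config)
    model.load_state_dict(torch.load('model/base4/model_0.5481.pt', map_location=device))
    model = model.to(device)
    model.eval()

    # Tokenize a single review and predict its label
    inputs = tokenizer(['这件商品质量很好,非常满意'], return_tensors='pt', padding='longest')
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    print(logits.argmax(dim=-1).item())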
