BiLSTM Sentiment Analysis (2)
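
This installment wires the BiLSTM classifier into a complete PyTorch Ignite training loop: a collate function that tokenizes reviews with BertTokenizer and sorts them for length-packing, a trainer with a linearly decayed learning rate, per-epoch validation with checkpointing, and a final evaluation on the test split.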

import torch
import torch.nn as nn
from transformers import BertTokenizer
from datasets import load_from_disk
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from ignite.engine import Engine
from ignite.engine import Events
from torch.utils.data import DataLoader
from ignite.contrib.handlers import ProgressBar
import numpy as np
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score
from ignite.contrib.handlers import PiecewiseLinear

# datasets.disable_progress_bar()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class SentimentClassification(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.word_embedding = nn.Embedding(vocab_size, 128)
        self.inputs_encode = nn.LSTM(128, 256, num_layers=1, bidirectional=True, batch_first=True)
        self.output_logits = nn.Linear(256 * 2, 1)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, inputs, length=None):

        inputs = self.word_embedding(inputs)
        inputs = self.dropout(inputs)
        if length is not None:
            # Pack the padded batch so the LSTM skips the padding positions
            inputs = pack_padded_sequence(inputs, length, batch_first=True)
        output, (hn, cn) = self.inputs_encode(inputs)
        # hn has shape (num_directions, batch_size, 256); concatenate the final
        # forward and backward hidden states into a (batch_size, 512) tensor
        hn = hn.transpose(0, 1)
        last_hidden_state = hn.reshape(hn.shape[0], -1)
        logits = self.output_logits(last_hidden_state)

        return logits, last_hidden_state
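
# Shape check (a hypothetical usage sketch, not part of the training script;
# 21128 is the vocabulary size of bert-base-chinese):
#   model = SentimentClassification(vocab_size=21128)
#   logits, hidden = model(torch.randint(0, 21128, (4, 20)))
#   logits has shape (4, 1); hidden has shape (4, 512)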


def train():

    def data_collate(batch_data):
        # Separate labels and texts, tokenize, sort by length, and pad
        batch_labels, batch_inputs = [], []
        for data in batch_data:
            batch_labels.append(data['label'])
            batch_inputs.append(data['content'])

        batch_inputs = [tokenizer.encode(inputs) for inputs in batch_inputs]
        batch_length = [len(input_ids) for input_ids in batch_inputs]
        sorted_index = np.argsort(-np.array(batch_length))

        # Sort by length in descending order, as required by
        # pack_padded_sequence with its default enforce_sorted=True
        sorted_inputs, sorted_labels, sorted_length = [], [], []
        for index in sorted_index:
            sorted_inputs.append(batch_inputs[index])
            sorted_labels.append(batch_labels[index])
            sorted_length.append(batch_length[index])

        # Zero-pad the inputs to the longest sequence in the batch
        sorted_inputs = pad_sequence([torch.tensor(inputs) for inputs in sorted_inputs], batch_first=True).to(device)
        sorted_labels = torch.tensor(sorted_labels, device=device, dtype=torch.float32)

        return sorted_inputs, sorted_labels, sorted_length

    def train_step(engine, batch_data):
        # One optimization step: forward pass, loss, backprop, update
        inputs, labels, length = batch_data
        engine.model.train()
        logits, _ = engine.model(inputs, length)
        loss = trainer.criterion(logits.flatten(), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the per-sample loss sum and the sample count so the epoch
        # handlers can average the loss over samples
        return {'batch_loss': loss.item() * len(labels), 'batch_iter': len(labels)}

    trainer = Engine(train_step)
    ProgressBar().attach(trainer)
    train_data = load_from_disk('data/senti-data')['train']
    train_data = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=data_collate)
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    model = SentimentClassification(vocab_size=tokenizer.vocab_size).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=5e-5)
    max_epochs = 30
    # Decay the learning rate linearly from 5e-5 to 0 over all training
    # iterations (len(train_data) is the number of batches per epoch)
    milestones_values = [(0, 5e-5), (max_epochs * len(train_data), 0)]
    scheduler = PiecewiseLinear(optimizer, param_name='lr', milestones_values=milestones_values)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

    # Stash shared objects on the engine so the step functions can reach
    # them through their `engine` argument
    trainer.model = model
    trainer.tokenizer = tokenizer
    trainer.criterion = criterion
    trainer.optimizer = optimizer

    @trainer.on(Events.EPOCH_STARTED)
    def on_train_epoch_started(engine):
        engine.total_loss = 0.0
        engine.total_iter = 0

    @trainer.on(Events.ITERATION_COMPLETED)
    def on_train_iteration_completed(engine):
        engine.total_loss += engine.state.output['batch_loss']
        engine.total_iter += engine.state.output['batch_iter']

    @trainer.on(Events.EPOCH_COMPLETED)
    def on_train_epoch_completed(engine):
        print('loss: %.5f' % (engine.total_loss / engine.total_iter))
        print('Validating the model...')
        validator.run(validator.valid_data)

    @trainer.on(Events.COMPLETED)
    def on_train_completed(engine):
        print('Evaluating the model on the test set...')
        tester.run(tester.test_data)

    # Evaluation step, shared by the validator and the tester
    def evaluate_step(engine, batch_data):
        inputs, labels, length = batch_data
        engine.model.eval()
        with torch.no_grad():
            logits, _ = engine.model(inputs, length)
            # Probabilities above 0.5 map to the positive class
            prediction = torch.where(torch.sigmoid(logits.flatten()) > 0.5, 1, 0)

        y_pred = prediction.cpu().numpy()
        y_true = labels.cpu().numpy()

        return {'y_pred': y_pred, 'y_true': y_true}


    validator = Engine(evaluate_step)
    validator.prev_acc = 0
    valid_data = load_from_disk('data/senti-data')['valid']
    validator.valid_data = DataLoader(valid_data, batch_size=16, collate_fn=data_collate)
    validator.model = model

    @validator.on(Events.EPOCH_STARTED)
    def on_valid_epoch_started(engine):
        engine.y_pred = []
        engine.y_true = []

    @validator.on(Events.ITERATION_COMPLETED)
    def on_valid_iteration_completed(engine):
        engine.y_pred.extend(engine.state.output['y_pred'])
        engine.y_true.extend(engine.state.output['y_true'])

    @validator.on(Events.COMPLETED)
    def on_valid_completed(engine):
        precision, recall, f_score, true_sum = \
            precision_recall_fscore_support(engine.y_true, engine.y_pred)
        accuracy = accuracy_score(engine.y_true, engine.y_pred)

        print('Precision:', '%.5f\t%.5f' % (precision[0], precision[1]))
        print('Recall:', '%.5f\t%.5f' % (recall[0], recall[1]))
        print('F1-score:', '%.5f\t%.5f' % (f_score[0], f_score[1]))
        print('Support:', '%d\t%d' % (true_sum[0], true_sum[1]))
        print('Accuracy:', '%.5f' % accuracy)
        print('-' * 55)

        # Save a checkpoint only when validation accuracy improves, so the
        # tester always evaluates the best model
        if accuracy > validator.prev_acc:
            validator.prev_acc = accuracy
            torch.save(model.state_dict(), 'model/bilstm.bin')

    # Evaluate the best model on the test split once training finishes
    tester = Engine(evaluate_step)
    test_data = load_from_disk('data/senti-data')['test']
    tester.test_data = DataLoader(test_data, batch_size=16, collate_fn=data_collate)
    test_model = SentimentClassification(vocab_size=tokenizer.vocab_size).to(device)
    tester.model = test_model

    # Load the checkpoint lazily: loading it here at setup time would read a
    # stale file (or fail on a fresh run), because training has not written
    # 'model/bilstm.bin' yet at this point in the script
    @tester.on(Events.STARTED)
    def on_test_started(engine):
        engine.model.load_state_dict(torch.load('model/bilstm.bin'))

    # Reuse the validator's accumulation handlers
    tester.add_event_handler(Events.STARTED, on_valid_epoch_started)
    tester.add_event_handler(Events.ITERATION_COMPLETED, on_valid_iteration_completed)

    @tester.on(Events.COMPLETED)
    def on_test_completed(engine):
        precision, recall, f_score, true_sum = \
            precision_recall_fscore_support(engine.y_true, engine.y_pred)
        accuracy = accuracy_score(engine.y_true, engine.y_pred)

        print('Test set results:')
        print('Precision:', '%.5f\t%.5f' % (precision[0], precision[1]))
        print('Recall:', '%.5f\t%.5f' % (recall[0], recall[1]))
        print('F1-score:', '%.5f\t%.5f' % (f_score[0], f_score[1]))
        print('Support:', '%d\t%d' % (true_sum[0], true_sum[1]))
        print('Accuracy:', '%.5f' % accuracy)
        print('-' * 55)

    trainer.run(train_data, max_epochs=max_epochs)


if __name__ == '__main__':
    train()
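
The script expects data/senti-data to be a dataset saved with the datasets library's save_to_disk, containing train, valid, and test splits whose rows have a content text column and a 0/1 label column. A minimal sketch for recreating that layout (the texts below are placeholders, not the original corpus):

from datasets import Dataset, DatasetDict

splits = DatasetDict({
    'train': Dataset.from_dict({'content': ['服务很好,下次还来', '太差了,不推荐'], 'label': [1, 0]}),
    'valid': Dataset.from_dict({'content': ['还不错'], 'label': [1]}),
    'test': Dataset.from_dict({'content': ['一般般'], 'label': [0]}),
})
splits.save_to_disk('data/senti-data')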