BiLSTM Sentiment Analysis (2)

The complete script below trains a bidirectional LSTM sentiment classifier, validates it after every epoch, and evaluates the saved checkpoint on the test set once training finishes.

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from transformers import BertTokenizer
from datasets import load_from_disk
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data import DataLoader
from ignite.engine import Engine, Events
from ignite.contrib.handlers import ProgressBar, PiecewiseLinear
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# datasets.disable_progress_bar()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class SentimentClassification(nn.Module):

    def __init__(self, vocab_size):
        super(SentimentClassification, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, 128)
        self.inputs_encode = nn.LSTM(128, 256, num_layers=1, bidirectional=True, batch_first=True)
        self.output_logits = nn.Linear(256 * 2, 1)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, inputs, length=None):
        inputs = self.word_embedding(inputs)
        inputs = self.dropout(inputs)
        if length is not None:
            # Sequences must already be sorted by descending length (done in data_collate)
            inputs = pack_padded_sequence(inputs, length, batch_first=True)
        output, (hn, cn) = self.inputs_encode(inputs)
        # Reshape hn from (2, batch_size, 256) to (batch_size, 512):
        # concatenate the forward and backward final hidden states per example
        hn = hn.transpose(0, 1)
        last_hidden_state = hn.reshape(hn.shape[0], -1)
        logits = self.output_logits(last_hidden_state)
        return logits, last_hidden_state


def train():

    def data_collate(batch_data):
        batch_labels, batch_inputs = [], []
        for data in batch_data:
            batch_labels.append(data['label'])
            batch_inputs.append(data['content'])
        batch_inputs = [tokenizer.encode(inputs) for inputs in batch_inputs]
        batch_length = [len(input_ids) for input_ids in batch_inputs]
        # Sort inputs and labels by descending length (required by pack_padded_sequence)
        sorted_index = np.argsort(-np.array(batch_length))
        sorted_inputs, sorted_labels, sorted_length = [], [], []
        for index in sorted_index:
            sorted_inputs.append(batch_inputs[index])
            sorted_labels.append(batch_labels[index])
            sorted_length.append(batch_length[index])
        # Zero-pad all sequences to the batch's max length
        sorted_inputs = pad_sequence([torch.tensor(inputs) for inputs in sorted_inputs], batch_first=True).to(device)
        sorted_labels = torch.tensor(sorted_labels, device=device, dtype=torch.float32)
        return sorted_inputs, sorted_labels, sorted_length

    def train_step(engine, batch_data):
        inputs, labels, length = batch_data
        engine.model.train()
        logits, _ = engine.model(inputs, length)
        loss = engine.criterion(logits.flatten(), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return {'batch_loss': loss.item() * len(labels), 'batch_iter': len(labels)}

    trainer = Engine(train_step)
    ProgressBar().attach(trainer)

    train_data = load_from_disk('data/senti-data')['train']
    num_labels = len(set(train_data['label']))
    train_data = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=data_collate)
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    model = SentimentClassification(vocab_size=tokenizer.vocab_size).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=5e-5)

    max_epochs = 30
    # Decay the learning rate linearly from 5e-5 to 0 over all training iterations
    milestones_values = [(0, 5e-5), (max_epochs * len(train_data), 0)]
    scheduler = PiecewiseLinear(optimizer, param_name='lr', milestones_values=milestones_values)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

    trainer.model = model
    trainer.tokenizer = tokenizer
    trainer.criterion = criterion
    trainer.optimizer = optimizer

    @trainer.on(Events.EPOCH_STARTED)
    def on_train_epoch_started(engine):
        engine.total_loss = 0.0
        engine.total_iter = 0

    @trainer.on(Events.ITERATION_COMPLETED)
    def on_train_iteration_completed(engine):
        engine.total_loss += engine.state.output['batch_loss']
        engine.total_iter += engine.state.output['batch_iter']

    @trainer.on(Events.EPOCH_COMPLETED)
    def on_train_epoch_completed(engine):
        print('loss: %.5f' % (engine.total_loss / engine.total_iter))
        print('Validating model...')
        validator.run(validator.valid_data)

    @trainer.on(Events.COMPLETED)
    def on_train_completed(engine):
        print('Evaluating model on the test set...')
        # Load the weights saved during validation; loading them here (instead of at
        # setup time) avoids a crash on the first run, before the checkpoint exists
        tester.model.load_state_dict(torch.load('model/bilstm.bin'))
        tester.run(tester.test_data)

    # Model validation
    def evaluate_step(engine, batch_data):
        inputs, labels, length = batch_data
        engine.model.eval()
        with torch.no_grad():
            logits, _ = engine.model(inputs, length)
        prediction = torch.where(torch.sigmoid(logits.flatten()) > 0.5, 1, 0)
        y_pred = prediction.cpu().numpy()
        y_true = labels.cpu().numpy()
        return {'y_pred': y_pred, 'y_true': y_true}

    validator = Engine(evaluate_step)
    validator.prev_acc = 0
    valid_data = load_from_disk('data/senti-data')['valid']
    validator.valid_data = DataLoader(valid_data, batch_size=16, collate_fn=data_collate)
    validator.model = model

    @validator.on(Events.EPOCH_STARTED)
    def on_valid_epoch_started(engine):
        engine.y_pred = []
        engine.y_true = []

    @validator.on(Events.ITERATION_COMPLETED)
    def on_valid_iteration_completed(engine):
        engine.y_pred.extend(engine.state.output['y_pred'])
        engine.y_true.extend(engine.state.output['y_true'])

    @validator.on(Events.COMPLETED)
    def on_valid_completed(engine):
        precision, recall, f_score, true_sum = \
            precision_recall_fscore_support(engine.y_true, engine.y_pred)
        accuracy = accuracy_score(engine.y_true, engine.y_pred)
        print('Precision:', '%.5f\t%.5f' % (precision[0], precision[1]))
        print('Recall:   ', '%.5f\t%.5f' % (recall[0], recall[1]))
        print('F1 score: ', '%.5f\t%.5f' % (f_score[0], f_score[1]))
        print('Support:  ', '%d\t%d' % (true_sum[0], true_sum[1]))
        print('Accuracy: ', '%.5f' % accuracy)
        print('-' * 55)
        # Save after every validation pass (best-accuracy checkpointing left commented out)
        # if accuracy > validator.prev_acc:
        #     validator.prev_acc = accuracy
        #     torch.save(model.state_dict(), 'model/bilstm.bin')
        torch.save(model.state_dict(), 'model/bilstm.bin')

    # Model evaluation (test set)
    tester = Engine(evaluate_step)
    test_data = load_from_disk('data/senti-data')['test']
    tester.test_data = DataLoader(test_data, batch_size=16, collate_fn=data_collate)
    test_model = SentimentClassification(vocab_size=tokenizer.vocab_size).to(device)
    tester.model = test_model
    tester.add_event_handler(Events.STARTED, on_valid_epoch_started)
    tester.add_event_handler(Events.ITERATION_COMPLETED, on_valid_iteration_completed)

    @tester.on(Events.COMPLETED)
    def on_test_completed(engine):
        precision, recall, f_score, true_sum = \
            precision_recall_fscore_support(engine.y_true, engine.y_pred)
        accuracy = accuracy_score(engine.y_true, engine.y_pred)
        print('Test set evaluation results:')
        print('Precision:', '%.5f\t%.5f' % (precision[0], precision[1]))
        print('Recall:   ', '%.5f\t%.5f' % (recall[0], recall[1]))
        print('F1 score: ', '%.5f\t%.5f' % (f_score[0], f_score[1]))
        print('Support:  ', '%d\t%d' % (true_sum[0], true_sum[1]))
        print('Accuracy: ', '%.5f' % accuracy)
        print('-' * 55)

    trainer.run(train_data, max_epochs=max_epochs)


if __name__ == '__main__':
    train()
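The key step in forward is collapsing the bidirectional LSTM's two final hidden states into one vector per example. A minimal standalone sketch (toy shapes assumed: batch of 4, sequence length 10) confirms that the transpose-then-reshape used above is equivalent to concatenating the forward and backward states:

import torch
import torch.nn as nn

lstm = torch.nn.LSTM(128, 256, num_layers=1, bidirectional=True, batch_first=True)
x = torch.randn(4, 10, 128)                  # (batch, seq_len, embed_dim)
_, (hn, _) = lstm(x)
print(hn.shape)                              # torch.Size([2, 4, 256])
last_hidden = hn.transpose(0, 1).reshape(4, -1)
print(last_hidden.shape)                     # torch.Size([4, 512])
# Equivalent, more explicit form: concatenate forward (hn[0]) and backward (hn[1]) states
assert torch.equal(last_hidden, torch.cat([hn[0], hn[1]], dim=-1))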
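For completeness, here is a hypothetical inference helper, not part of the original script: predict and the example sentence are illustrative, and the sketch assumes the SentimentClassification class defined above plus a checkpoint already saved at model/bilstm.bin.

import torch
from transformers import BertTokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def predict(text, model, tokenizer):
    """Return (label, probability) for a single sentence."""
    model.eval()
    # Batch of one: no padding needed, so the length argument is omitted
    input_ids = torch.tensor([tokenizer.encode(text)], device=device)
    with torch.no_grad():
        logits, _ = model(input_ids)
    prob = torch.sigmoid(logits.flatten()).item()
    return ('positive' if prob > 0.5 else 'negative'), prob

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = SentimentClassification(vocab_size=tokenizer.vocab_size).to(device)
model.load_state_dict(torch.load('model/bilstm.bin', map_location=device))
print(predict('这家店的服务特别好', model, tokenizer))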