We use the TNEWS dataset, which is built from the news section of Toutiao and covers 15 news categories, including travel, education, finance, and military. Dataset sizes: training set (53,360), validation set (10,000), test set (10,000). Example record:
{"label": "102", "label_desc": "news_entertainment", "sentence": "江疏影甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物"}
Each record has three fields, in order: the category ID, the category name, and the news text (title only).
Knowledge distillation paper: https://arxiv.org/pdf/1903.12136.pdf
Dataset: https://github.com/CLUEbenchmark/CLUE
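The data files are in JSON Lines format, one record per line. A minimal sketch of parsing a single record (using the example line shown above):

```python
import json

# Each line of train.json / dev.json is one JSON object carrying the three
# fields described above; the sample line here is the example record.
line = '{"label": "102", "label_desc": "news_entertainment", "sentence": "江疏影甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物"}'
record = json.loads(line)
print(record['label'], record['label_desc'], record['sentence'])
```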
The experiments cover the following parts:
- fine-tune the bert-base-chinese model and evaluate it on the validation set;
- train a Bi-LSTM model from scratch and evaluate it on the validation set;
- have the student Bi-LSTM learn from the teacher BERT model and evaluate it on the validation set;
- submit test-set predictions to CLUE to obtain the model score.
1. Preparing the Training Data
Dataset directory structure:
```
|-- README.txt
|-- dev.json
|-- labels.json
|-- test.json
|-- test1.0.json
`-- train.json
```
The data-processing step does only a few things:
- normalize each sentence;
- map the labels to contiguous integer labels starting from 0;
- store the result in datasets.Dataset format.
Example code:
```python
from pyhanlp import JClass
from datasets import Dataset
from datasets import DatasetDict
from collections import Counter
import json
import os
import pandas as pd


def load_data():
    # Build the label mapping: original label id -> contiguous index starting at 0
    label_mapping = {}
    with open('tnews/labels.json') as file:
        for index, line in enumerate(file):
            line = json.loads(line)
            label_mapping[line['label']] = {'label': index, 'label_desc': line['label_desc']}

    # Text normalization (e.g. full-width to half-width) via HanLP's CharTable
    normalizer = JClass('com.hankcs.hanlp.dictionary.other.CharTable')
    def clear_text(text):
        text = normalizer.convert(text)
        return text

    # Training set
    train_data = {'label': [], 'sentence': []}
    vocab = []
    with open('tnews/train.json') as file:
        for line in file:
            line = json.loads(line)
            label = label_mapping[line['label']]['label']
            sentence = clear_text(line['sentence'])
            train_data['label'].append(label)
            train_data['sentence'].append(sentence)
            vocab.extend(sentence)
    train_data = Dataset.from_dict(train_data)

    # Validation set
    valid_data = {'label': [], 'sentence': []}
    with open('tnews/dev.json') as file:
        for line in file:
            line = json.loads(line)
            valid_data['label'].append(label_mapping[line['label']]['label'])
            valid_data['sentence'].append(clear_text(line['sentence']))
    valid_data = Dataset.from_dict(valid_data)

    # Build the character vocabulary (keep characters that appear more than twice)
    os.makedirs('data', exist_ok=True)
    word_freq = Counter(vocab)
    unique_vocab = [word for word in set(vocab) if word_freq[word] > 2]
    unique_vocab.insert(0, '[UNK]')
    unique_vocab.insert(0, '[PAD]')
    with open('data/vocab.txt', 'w') as file:
        file.write('\n'.join(unique_vocab))

    # Save the label mapping
    pd.DataFrame(label_mapping).transpose().to_csv('data/label.csv')
    # Save the datasets
    DatasetDict({'train': train_data, 'valid': valid_data}).save_to_disk('data')


if __name__ == '__main__':
    load_data()
```
After the script finishes, the following files are produced:
```
data
├── dataset_dict.json
├── label.csv
├── train
│   ├── dataset.arrow
│   ├── dataset_info.json
│   └── state.json
├── valid
│   ├── dataset.arrow
│   ├── dataset_info.json
│   └── state.json
└── vocab.txt
```
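To double-check these artifacts, the saved dataset and label mapping can be loaded back. A quick sketch, assuming the paths produced above:

```python
import pandas as pd
from datasets import load_from_disk

# Reload the DatasetDict written by load_data() and the label mapping CSV.
data = load_from_disk('data')
print(data)               # splits: 'train' and 'valid'
print(data['train'][0])   # {'label': ..., 'sentence': ...}

labels = pd.read_csv('data/label.csv', index_col=0)
print(labels.head())      # original label id -> contiguous index and description
```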
2. Fine-tuning the BERT Model
```python
import torch
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
from torchmetrics import Accuracy
from torch.optim.lr_scheduler import LinearLR
from datasets import load_from_disk

# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# Computation device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_data():
    # Load the dataset saved in the previous step
    data = load_from_disk('data')
    train_data, valid_data = data['train'], data['valid']

    def collate_fn(batch_data):
        batch_inputs = []
        batch_labels = []
        for data in batch_data:
            batch_inputs.append(data['sentence'])
            batch_labels.append(data['label'])
        batch_inputs = tokenizer(batch_inputs, padding=True, return_tensors='pt')
        batch_inputs = {key: value.to(device) for key, value in batch_inputs.items()}
        batch_labels = torch.tensor(batch_labels, device=device)
        return batch_inputs, batch_labels

    params = {'batch_size': 16, 'collate_fn': collate_fn}
    train_loader = DataLoader(train_data, **params, shuffle=True)
    valid_loader = DataLoader(valid_data, **params, shuffle=False)
    return train_loader, valid_loader


def train_model():
    train_loader, valid_loader = load_data()
    estimator = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=15)
    estimator = estimator.to(device)
    optimizer = optim.AdamW(estimator.parameters(), lr=1e-5)
    scheduler = LinearLR(optimizer=optimizer, start_factor=0.3, end_factor=1, total_iters=5000)
    num_epoch = 10
    metric = Accuracy().to(device)

    for epoch_idx in range(num_epoch):
        progress = tqdm(range(len(train_loader)))
        total_loss = 0.0
        for batch_inputs, batch_labels in train_loader:
            outputs = estimator(**batch_inputs, labels=batch_labels)
            optimizer.zero_grad()
            outputs.loss.backward()
            optimizer.step()
            scheduler.step()
            progress.set_description('epoch %2d iter loss %.2f' % (epoch_idx + 1, outputs.loss.item()))
            progress.update()
            total_loss += (outputs.loss.item() * len(batch_labels))
        progress.set_description('epoch %2d total loss %.2f' % (epoch_idx + 1, total_loss))
        progress.close()

        # Evaluate on the validation set
        with torch.no_grad():
            progress = tqdm(range(len(valid_loader)))
            for batch_inputs, batch_labels in valid_loader:
                outputs = estimator(**batch_inputs)
                metric.update(outputs.logits, batch_labels)
                progress.set_description('epoch %2d valid acc 0.000' % (epoch_idx + 1))
                progress.update()
            acc = metric.compute()
            progress.set_description('epoch %2d valid acc %.3f' % (epoch_idx + 1, acc))
            progress.close()

        # Save the model for this epoch, named "<epoch>-<validation accuracy>"
        model_name = 'model1/{}-{:.3f}'.format(epoch_idx + 1, acc)
        estimator.save_pretrained(model_name)
        tokenizer.save_pretrained(model_name)


if __name__ == '__main__':
    train_model()
```
After training, the saved models are:
1-0.559 2-0.562 3-0.564 4-0.565 5-0.565 6-0.564 7-0.563 8-0.563 9-0.563
The first number is the epoch and the second is the model's accuracy (Acc) on the validation set.
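Since every checkpoint directory is named `<epoch>-<accuracy>`, picking the best one can be automated. A small hypothetical helper (not part of the original scripts):

```python
import os

def best_checkpoint(model_dir='model1'):
    # Directory names look like "5-0.565"; sort by the accuracy suffix.
    names = os.listdir(model_dir)
    best = max(names, key=lambda name: float(name.split('-')[-1]))
    return os.path.join(model_dir, best)

print(best_checkpoint('model1'))
```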
3. Training the Bi-LSTM Model
```python
import torch
import torch.nn as nn
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import LinearLR
from torchmetrics import Accuracy
from tqdm import tqdm
import numpy as np
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_sequence
from datasets import load_from_disk
import os

# Initialize the tokenizer with the character vocabulary built earlier
tokenizer = BertTokenizer(vocab_file='data/vocab.txt')
# Computation device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_data():
    # Load the dataset
    data = load_from_disk('data')
    train_data, valid_data = data['train'], data['valid']

    def collate_fn(batch_data):
        batch_inputs = []
        batch_labels = []
        for data in batch_data:
            batch_inputs.append(data['sentence'])
            batch_labels.append(data['label'])
        batch_inputs = tokenizer(batch_inputs,
                                 padding=False,
                                 return_token_type_ids=False,
                                 return_special_tokens_mask=False,
                                 add_special_tokens=False)['input_ids']
        # Sequence lengths in this batch
        batch_length = [len(inputs) for inputs in batch_inputs]
        # Sort by descending length (required by pack_padded_sequence)
        sorted_index = np.argsort(-np.array(batch_length))
        # Reorder inputs, labels and lengths accordingly
        sorted_inputs, sorted_labels, sorted_length = [], [], []
        for index in sorted_index:
            sorted_inputs.append(torch.tensor(batch_inputs[index], device=device))
            sorted_labels.append(batch_labels[index])
            sorted_length.append(batch_length[index])
        # Zero-pad the sequences and convert to tensors
        sorted_inputs = pad_sequence(sorted_inputs, batch_first=True)
        sorted_labels = torch.tensor(sorted_labels, device=device)
        return sorted_inputs, sorted_labels, sorted_length

    params = {'batch_size': 16, 'collate_fn': collate_fn}
    train_loader = DataLoader(train_data, **params, shuffle=True)
    valid_loader = DataLoader(valid_data, **params, shuffle=False)
    return train_loader, valid_loader


class SequenceClassification(nn.Module):

    def __init__(self):
        super(SequenceClassification, self).__init__()
        self.embedds = nn.Embedding(num_embeddings=tokenizer.vocab_size, embedding_dim=512)
        self.encoder = nn.LSTM(input_size=512, hidden_size=512, batch_first=True, bidirectional=True)
        self.outputs = nn.Linear(in_features=512 * 2, out_features=15)

    def forward(self, inputs, lengths=None):
        inputs = self.embedds(inputs)
        if lengths is not None:
            inputs = pack_padded_sequence(inputs, lengths=lengths, batch_first=True)
        # inputs: encoding of every token
        # hn: hidden state of the last time step (both directions)
        # cn: cell state
        inputs, (hn, cn) = self.encoder(inputs)
        inputs = torch.tanh(hn)
        inputs = inputs.transpose(0, 1)
        inputs = inputs.reshape(inputs.shape[0], -1)
        inputs = self.outputs(inputs)
        return inputs


def train_model():
    train_loader, valid_loader = load_data()
    estimator = SequenceClassification()
    estimator = estimator.to(device)
    optimizer = optim.AdamW(estimator.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    scheduler = LinearLR(optimizer=optimizer, start_factor=0.3, end_factor=1, total_iters=1000)
    num_epoch = 50
    metric = Accuracy().to(device)

    for epoch_idx in range(num_epoch):
        progress = tqdm(range(len(train_loader)))
        total_loss = 0.0
        for sorted_inputs, sorted_labels, sorted_length in train_loader:
            outputs = estimator(sorted_inputs, sorted_length)
            optimizer.zero_grad()
            loss = criterion(outputs.squeeze(), sorted_labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            progress.set_description('epoch %2d iter loss %.2f' % (epoch_idx + 1, loss.item()))
            progress.update()
            total_loss += (loss.item() * len(sorted_labels))
        progress.set_description('epoch %2d total loss %.2f' % (epoch_idx + 1, total_loss))
        progress.close()

        # Evaluate on the validation set
        with torch.no_grad():
            progress = tqdm(range(len(valid_loader)))
            for sorted_inputs, sorted_labels, sorted_length in valid_loader:
                outputs = estimator(sorted_inputs, sorted_length)
                metric.update(outputs.squeeze(), sorted_labels)
                progress.set_description('epoch %2d valid acc 0.000' % (epoch_idx + 1))
                progress.update()
            acc = metric.compute()
            progress.set_description('epoch %2d valid acc %.3f' % (epoch_idx + 1, acc))
            progress.close()

        # Save the model for this epoch, named "<epoch>-<validation accuracy>"
        model_name = 'model2/{}-{:.4f}'.format(epoch_idx + 1, acc)
        os.makedirs(model_name, exist_ok=True)
        torch.save(estimator.state_dict(), model_name + '/bilstm.bin')
        tokenizer.save_pretrained(model_name)


if __name__ == '__main__':
    train_model()
```
After training, all of the produced models are (sorted by epoch):
```
 1-0.4767  2-0.4873  3-0.4941  4-0.4972  5-0.4994  6-0.4997  7-0.5013  8-0.5019  9-0.5014 10-0.5016
11-0.5014 12-0.5009 13-0.5010 14-0.5011 15-0.5012 16-0.5010 17-0.5006 18-0.5008 19-0.5009 20-0.5009
21-0.5008 22-0.5007 23-0.5008 24-0.5007 25-0.5007 26-0.5007 27-0.5008 28-0.5009 29-0.5010 30-0.5010
31-0.5010 32-0.5009 33-0.5009 34-0.5009 35-0.5010 36-0.5011 37-0.5011 38-0.5011 39-0.5011 40-0.5010
41-0.5010 42-0.5009 43-0.5008 44-0.5006 45-0.5004 46-0.5004 47-0.5004 48-0.5003 49-0.5002 50-0.5000
```
The best model appears to be the one from epoch 8, with an accuracy of 50.19%.
4. Distilling BERT's Knowledge into the BiLSTM
This step is straightforward:
- first, run BERT over the training set and store its output logits;
- then train the BiLSTM with an MSE loss between its logits on the training set and the stored BERT logits (the objective is written out after this list);
- during training, report the BiLSTM's accuracy on the validation set.
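In the paper's notation, the distillation term is the squared L2 distance between the teacher's logits $\boldsymbol{z}^{(B)}$ and the student's logits $\boldsymbol{z}^{(S)}$, which can be combined with the usual cross-entropy loss through a weight $\alpha$; the experiment below keeps only the MSE term:

$$\mathcal{L} = \alpha \cdot \mathcal{L}_{\mathrm{CE}} + (1 - \alpha) \cdot \left\lVert \boldsymbol{z}^{(B)} - \boldsymbol{z}^{(S)} \right\rVert_2^2$$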
Example code:
```python
import torch
import torch.nn as nn
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import load_from_disk
import torch.optim as optim
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from torchmetrics import Accuracy
import os

# Computation device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_teacher():
    # Tokenizer of the fine-tuned teacher model
    tokenizer = BertTokenizer.from_pretrained('model1/5-0.565')
    # Load the dataset
    data = load_from_disk('data')
    train_data, valid_data = data['train'], data['valid']

    def collate_fn(batch_data):
        batch_inputs = [data['sentence'] for data in batch_data]
        batch_inputs = tokenizer(batch_inputs, padding=True, return_tensors='pt')
        batch_inputs = {key: value.to(device) for key, value in batch_inputs.items()}
        return batch_inputs, batch_data

    params = {'batch_size': 16, 'collate_fn': collate_fn}
    train_loader = DataLoader(train_data, **params, shuffle=False)
    return train_loader


def teacher_do():
    # Run the teacher over the training set and cache its logits
    train_loader = load_teacher()
    estimator = BertForSequenceClassification.from_pretrained('model1/5-0.565', num_labels=15)
    estimator = estimator.to(device)
    estimator.eval()

    progress = tqdm(range(len(train_loader)))
    teacher_knowledge = {'label': [], 'sentence': []}
    for batch_inputs, batch_data in train_loader:
        with torch.no_grad():
            outputs = estimator(**batch_inputs).logits
        for logit, sample in zip(outputs.cpu().numpy().tolist(), batch_data):
            teacher_knowledge['label'].append(logit)
            teacher_knowledge['sentence'].append(sample['sentence'])
        progress.update()
    progress.close()
    # Save the teacher logits ("knowledge")
    Dataset.from_dict(teacher_knowledge).save_to_disk('knowledge')


class SequenceClassification(nn.Module):

    def __init__(self):
        super(SequenceClassification, self).__init__()
        self.embedds = nn.Embedding(num_embeddings=tokenizer.vocab_size, embedding_dim=512)
        self.encoder = nn.LSTM(input_size=512, hidden_size=512, batch_first=True, bidirectional=True)
        self.outputs = nn.Linear(in_features=512 * 2, out_features=15)

    def forward(self, inputs, lengths=None):
        inputs = self.embedds(inputs)
        if lengths is not None:
            inputs = pack_padded_sequence(inputs, lengths=lengths, batch_first=True)
        # inputs: encoding of every token
        # hn: hidden state of the last time step (both directions)
        # cn: cell state
        inputs, (hn, cn) = self.encoder(inputs)
        inputs = torch.tanh(hn)
        inputs = inputs.transpose(0, 1)
        inputs = inputs.reshape(inputs.shape[0], -1)
        inputs = self.outputs(inputs)
        return inputs


# Student tokenizer based on the character vocabulary
tokenizer = BertTokenizer(vocab_file='data/vocab.txt')


def load_student():
    # The validation set keeps the hard labels; the training set uses the teacher logits
    valid_data = load_from_disk('data')['valid']
    train_data = load_from_disk('knowledge')

    def collate_fn(batch_data):
        batch_inputs = []
        batch_labels = []
        for data in batch_data:
            batch_inputs.append(data['sentence'])
            batch_labels.append(data['label'])
        batch_inputs = tokenizer(batch_inputs,
                                 padding=False,
                                 return_token_type_ids=False,
                                 return_special_tokens_mask=False,
                                 add_special_tokens=False)['input_ids']
        # Sequence lengths in this batch
        batch_length = [len(inputs) for inputs in batch_inputs]
        # Sort by descending length (required by pack_padded_sequence)
        sorted_index = np.argsort(-np.array(batch_length))
        # Reorder inputs, labels and lengths accordingly
        sorted_inputs, sorted_labels, sorted_length = [], [], []
        for index in sorted_index:
            sorted_inputs.append(torch.tensor(batch_inputs[index], device=device))
            sorted_labels.append(batch_labels[index])
            sorted_length.append(batch_length[index])
        # Zero-pad the sequences and convert to tensors
        sorted_inputs = pad_sequence(sorted_inputs, batch_first=True)
        sorted_labels = torch.tensor(sorted_labels, device=device)
        return sorted_inputs, sorted_labels, sorted_length

    params = {'batch_size': 16, 'collate_fn': collate_fn}
    train_loader = DataLoader(train_data, **params, shuffle=True)
    valid_loader = DataLoader(valid_data, **params, shuffle=False)
    return train_loader, valid_loader


def student_do():
    train_loader, valid_loader = load_student()
    estimator = SequenceClassification().to(device)
    optimizer = optim.Adam(estimator.parameters(), lr=1e-4)
    criterion = nn.MSELoss()
    metric = Accuracy().to(device)
    num_epoch = 20

    for epoch_idx in range(num_epoch):
        progress = tqdm(range(len(train_loader)))
        total_loss = 0.0
        for sorted_inputs, sorted_labels, sorted_length in train_loader:
            outputs = estimator(sorted_inputs, sorted_length)
            optimizer.zero_grad()
            # MSE between the student logits and the cached teacher logits
            loss = criterion(outputs.squeeze(), sorted_labels)
            loss.backward()
            optimizer.step()
            progress.set_description('epoch %2d iter loss %.2f' % (epoch_idx + 1, loss.item()))
            progress.update()
            total_loss += (loss.item() * len(sorted_labels))
        progress.set_description('epoch %2d total loss %.2f' % (epoch_idx + 1, total_loss))
        progress.close()

        # Evaluate on the validation set against the hard labels
        with torch.no_grad():
            progress = tqdm(range(len(valid_loader)))
            for sorted_inputs, sorted_labels, sorted_length in valid_loader:
                outputs = estimator(sorted_inputs, sorted_length)
                metric.update(outputs.squeeze(), sorted_labels)
                progress.set_description('epoch %2d valid acc 0.000' % (epoch_idx + 1))
                progress.update()
            acc = metric.compute()
            progress.set_description('epoch %2d valid acc %.3f' % (epoch_idx + 1, acc))
            progress.close()

        # Save the model for this epoch, named "<epoch>-<validation accuracy>"
        model_name = 'model3/{}-{:.4f}'.format(epoch_idx + 1, acc)
        os.makedirs(model_name, exist_ok=True)
        torch.save(estimator.state_dict(), model_name + '/bilstm.bin')
        tokenizer.save_pretrained(model_name)


if __name__ == '__main__':
    if not os.path.exists('knowledge'):
        teacher_do()  # cache the teacher logits once
    student_do()
```
Output during training:
```
epoch 1 total loss 197160.26: 100%|████████| 3335/3335 [00:44<00:00, 75.35it/s]
epoch 1 valid acc 0.477: 100%|██████████████| 625/625 [00:04<00:00, 133.95it/s]
epoch 2 total loss 147970.64: 100%|████████| 3335/3335 [00:43<00:00, 75.85it/s]
epoch 2 valid acc 0.494: 100%|██████████████| 625/625 [00:04<00:00, 134.77it/s]
epoch 3 total loss 129521.85: 100%|████████| 3335/3335 [00:44<00:00, 75.79it/s]
epoch 3 valid acc 0.501: 100%|██████████████| 625/625 [00:04<00:00, 134.88it/s]
epoch 4 total loss 114689.55: 100%|████████| 3335/3335 [00:44<00:00, 75.78it/s]
epoch 4 valid acc 0.507: 100%|██████████████| 625/625 [00:04<00:00, 134.75it/s]
epoch 5 total loss 100872.29: 100%|████████| 3335/3335 [00:43<00:00, 75.94it/s]
epoch 5 valid acc 0.510: 100%|██████████████| 625/625 [00:04<00:00, 134.70it/s]
epoch 6 total loss 87712.62: 100%|█████████| 3335/3335 [00:43<00:00, 75.87it/s]
epoch 6 valid acc 0.513: 100%|██████████████| 625/625 [00:04<00:00, 135.19it/s]
epoch 7 total loss 75687.72: 100%|█████████| 3335/3335 [00:44<00:00, 75.77it/s]
epoch 7 valid acc 0.515: 100%|██████████████| 625/625 [00:04<00:00, 134.78it/s]
epoch 8 total loss 64895.03: 100%|█████████| 3335/3335 [00:44<00:00, 75.76it/s]
epoch 8 valid acc 0.517: 100%|██████████████| 625/625 [00:04<00:00, 135.04it/s]
epoch 9 total loss 55662.15: 100%|█████████| 3335/3335 [00:44<00:00, 75.67it/s]
epoch 9 valid acc 0.518: 100%|██████████████| 625/625 [00:04<00:00, 134.84it/s]
epoch 10 total loss 47618.56: 100%|█████████| 3335/3335 [00:43<00:00, 75.80it/s]
epoch 10 valid acc 0.519: 100%|██████████████| 625/625 [00:04<00:00, 134.93it/s]
epoch 11 total loss 40930.74: 100%|█████████| 3335/3335 [00:44<00:00, 75.69it/s]
epoch 11 valid acc 0.519: 100%|██████████████| 625/625 [00:04<00:00, 134.59it/s]
epoch 12 total loss 35400.10: 100%|█████████| 3335/3335 [00:44<00:00, 75.66it/s]
epoch 12 valid acc 0.520: 100%|██████████████| 625/625 [00:04<00:00, 134.85it/s]
epoch 13 total loss 30833.56: 100%|█████████| 3335/3335 [00:44<00:00, 75.73it/s]
epoch 13 valid acc 0.521: 100%|██████████████| 625/625 [00:04<00:00, 134.48it/s]
epoch 14 total loss 26831.78: 100%|█████████| 3335/3335 [00:44<00:00, 75.71it/s]
epoch 14 valid acc 0.521: 100%|██████████████| 625/625 [00:04<00:00, 134.95it/s]
epoch 15 total loss 23764.05: 100%|█████████| 3335/3335 [00:44<00:00, 75.78it/s]
epoch 15 valid acc 0.521: 100%|██████████████| 625/625 [00:04<00:00, 135.02it/s]
epoch 16 total loss 20811.81: 100%|█████████| 3335/3335 [00:44<00:00, 75.77it/s]
epoch 16 valid acc 0.522: 100%|██████████████| 625/625 [00:04<00:00, 134.85it/s]
epoch 17 total loss 18524.72: 100%|█████████| 3335/3335 [00:44<00:00, 75.71it/s]
epoch 17 valid acc 0.522: 100%|██████████████| 625/625 [00:04<00:00, 134.95it/s]
epoch 18 total loss 16605.46: 100%|█████████| 3335/3335 [00:44<00:00, 75.69it/s]
epoch 18 valid acc 0.523: 100%|██████████████| 625/625 [00:04<00:00, 134.85it/s]
epoch 19 total loss 14851.84: 100%|█████████| 3335/3335 [00:44<00:00, 75.72it/s]
epoch 19 valid acc 0.523: 100%|██████████████| 625/625 [00:04<00:00, 134.75it/s]
epoch 20 total loss 13164.29: 100%|█████████| 3335/3335 [00:44<00:00, 75.75it/s]
epoch 20 valid acc 0.523: 100%|██████████████| 625/625 [00:04<00:00, 134.79it/s]
```
With distillation, the BiLSTM's validation accuracy rises from the 50.19% obtained by direct training to 52.3%, a clear improvement. Tuning the hyperparameters or training for more epochs would likely push the BiLSTM's accuracy higher still.
Here the BiLSTM is distilled on the original training set only. The paper also proposes augmenting the data so that the student can absorb more dark knowledge from the teacher, and it describes the augmentation methods.
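As an illustration only (not part of the original experiment), here is a rough sketch of two of the paper's augmentation strategies, masking and n-gram sampling, applied at the character level to match the tokenization used above; the probabilities and n-gram range are placeholder values:

```python
import random

def augment(sentence, mask_token='[MASK]', p_mask=0.1, p_ngram=0.25, n_range=(2, 10)):
    # Masking: randomly replace characters with a mask token.
    chars = [mask_token if random.random() < p_mask else ch for ch in sentence]
    # N-gram sampling: occasionally keep only a random n-gram window of the sentence.
    if random.random() < p_ngram and len(chars) > n_range[1]:
        n = random.randint(*n_range)
        start = random.randint(0, len(chars) - n)
        chars = chars[start:start + n]
    return ''.join(chars)

# Augmented copies of the training sentences would then be fed to the teacher
# to obtain additional soft labels for the student.
print(augment('江疏影甜甜圈自拍,迷之角度竟这么好看'))
```

The paper additionally describes POS-guided word replacement, which requires a POS tagger and is omitted from this sketch.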