Data processing mainly consists of loading the corpus, building the vocabulary, and converting the dataset into index representations. Here we drop sentences longer than 505 characters. Since the vocabulary is built from the training set only, OOV (out-of-vocabulary) characters may appear when encoding the test set; we simply represent them with [UNK], and this part of the work can be delegated to BertTokenizer. Dataset link:
https://www.aliyundrive.com/s/HkNk5zog6gi (extraction code: 60oq)
import pandas as pd
import torch
import pickle
from datasets import Dataset
from datasets import DatasetDict
1. Loading the Corpus
Loading the corpus means merging the text and labels, which live in separate txt files, into a single csv file for easier downstream processing. During this step, samples longer than 505 characters are also filtered out. After the function runs, two files are produced: 01-训练集.csv and 02-测试集.csv.
def load_corpus():
    train_path = ['msra/train/sentences.txt', 'msra/train/tags.txt']
    valid_path = ['msra/valid/sentences.txt', 'msra/valid/tags.txt']
    data_path = [train_path, valid_path]
    # 1. Read the training data (the train and valid splits are merged into one training file)
    data_inputs, data_labels = [], []
    for x_path, y_path in data_path:
        for inputs, labels in zip(open(x_path), open(y_path)):
            inputs = inputs.split()
            labels = labels.split()
            # Drop sentences longer than 505 characters
            if len(inputs) > 505:
                continue
            # Drop samples whose input and label lengths do not match
            if len(inputs) != len(labels):
                continue
            data_labels.append(' '.join(labels))
            data_inputs.append(' '.join(inputs))
    # Save the training data
    train_data = pd.DataFrame()
    train_data['data_inputs'] = data_inputs
    train_data['data_labels'] = data_labels
    train_data.to_csv('data/01-训练集.csv')
    print('Training set size:', len(train_data))
    # 2. Read the test data
    test_input_path = 'msra/test/sentences.txt'
    test_label_path = 'msra/test/tags.txt'
    data_inputs, data_labels = [], []
    for inputs, labels in zip(open(test_input_path), open(test_label_path)):
        inputs = inputs.split()
        labels = labels.split()
        if len(inputs) > 505:
            continue
        if len(inputs) != len(labels):
            continue
        data_labels.append(' '.join(labels))
        data_inputs.append(' '.join(inputs))
    # Save the test data
    test_data = pd.DataFrame()
    test_data['data_inputs'] = data_inputs
    test_data['data_labels'] = data_labels
    test_data.to_csv('data/02-测试集.csv')
    print('Test set size:', len(test_data))
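To get a feel for the resulting format, a minimal sketch like the following (assuming load_corpus has already been run) reads back 01-训练集.csv and prints the first sample; each row stores space-separated characters and their space-separated BIO tags:

import pandas as pd

df = pd.read_csv('data/01-训练集.csv')
print(df.loc[0, 'data_inputs'])   # space-separated characters of the first sentence
print(df.loc[0, 'data_labels'])   # the corresponding space-separated BIO tags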
2. Building the Vocabulary
Since we will use BertTokenizer later, here we build a vocabulary file bilstm_crf_vocab.txt from 01-训练集.csv: one character per line, with [PAD] and [UNK] as the first two entries.

The complete implementation is as follows:
def build_vocab():
    data_inputs = pd.read_csv('data/01-训练集.csv', usecols=['data_inputs',]).values
    words = []
    for data_input in data_inputs:
        data_input = data_input[0].split()
        words.extend(data_input)
    unique_words = list(set(words))
    unique_words.insert(0, '[UNK]')
    unique_words.insert(0, '[PAD]')
    # Write the characters to the data/bilstm_crf_vocab.txt vocabulary file
    with open('data/bilstm_crf_vocab.txt', 'w') as file:
        for word in unique_words:
            file.write(word + '\n')
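As a quick sanity check of how OOV characters end up as [UNK], a minimal sketch (assuming the transformers library is installed, build_vocab has been run, and using a made-up example sentence) loads this vocabulary into a BertTokenizer:

from transformers import BertTokenizer

tokenizer = BertTokenizer(vocab_file='data/bilstm_crf_vocab.txt')
tokens = ['我', '爱', '北', '京']  # character-level tokens, matching the corpus format
ids = tokenizer.convert_tokens_to_ids(tokens)
print(ids)  # any character not in the training vocabulary maps to the [UNK] id (index 1)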
3. Label Encoding
Here the labels in the dataset are converted to numeric indices, and the resulting DatasetDict is saved to the data/bilstm_crf_data directory.
def encode_label():
    labels = ['O', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER', 'B-LOC', 'I-LOC']
    label_to_index = {label: index for index, label in enumerate(labels)}
    # Convert the csv data into Dataset objects
    train_data = pd.read_csv('data/01-训练集.csv')
    valid_data = pd.read_csv('data/02-测试集.csv')
    train_data = Dataset.from_pandas(train_data)
    valid_data = Dataset.from_pandas(valid_data)
    corpus_data = DatasetDict({'train': train_data, 'valid': valid_data})
    # Convert the label strings to index representations
    def data_handler(data_labels, data_inputs):
        data_label_ids = []
        for labels in data_labels:
            label_ids = []
            for label in labels.split():
                label_ids.append(label_to_index[label])
            data_label_ids.append(label_ids)
        return {'data_labels': data_label_ids, 'data_inputs': data_inputs}
    corpus_data = corpus_data.map(data_handler, input_columns=['data_labels', 'data_inputs'], batched=True)
    # Save the dataset to disk
    corpus_data.save_to_disk('data/bilstm_crf_data')
if __name__ == '__main__':
    load_corpus()
    build_vocab()
    encode_label()
Program output:
Training set size: 44968
Test set size: 3438
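Once the script has finished, the saved DatasetDict can be loaded back for training. A minimal sketch (assuming the processing above has run):

from datasets import load_from_disk

corpus_data = load_from_disk('data/bilstm_crf_data')
print(corpus_data)              # shows the train/valid splits and their columns
print(corpus_data['train'][0])  # first sample: characters plus label indices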
