很多资料表明,通过文本数据增强也能够增强模型的分类性能。本篇文章总结几种文本数据增强的方法:
- 马尔科夫链文本增强
- 百度回译数据增强
- EDA 数据增强
1. 数据信息简单展示
接下来演示使用的数据为 json 格式,其共有 6 个类别,分别使用 0-5 的数字来表示,下面的代码对数据进行简单处理并可视化。
import pandas as pd from collections import Counter import matplotlib.pyplot as plt import seaborn as sns from pyhanlp import JClass def data_plot(data): _, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=100) plt.subplots_adjust(wspace=0.5,hspace=0.6) clases = Counter(data['label']) sns.barplot(x=list(clases.keys()), y=list(clases.values()), ax=axes[0]) axes[0].set_title('类别样本数量') axes[0].set_xlabel('class') axes[0].set_ylabel('Count') for x, y in clases.items(): axes[0].text(x, y+0.5, y, ha='center') length = data['sentence'].map(lambda x: len(x)) sns.histplot(x=length, ax=axes[1]) axes[1].set_title('句子长度分布') plt.show() def data_split(): data = pd.read_json('data/senti-data.json') data.columns = ['sentence', 'label'] normalizer = JClass('com.hankcs.hanlp.dictionary.other.CharTable') def clean_data(sentence): sentence = normalizer.convert(sentence) sentence = ''.join(sentence.split()) return sentence sentence = [clean_data(sentence) for sentence in data['sentence']] data['sentence'] = sentence data_plot(data) print(data.info()) data.to_csv('data/train.csv') if __name__ == '__main__': data_split()
RangeIndex: 40133 entries, 0 to 40132 Data columns (total 2 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 sentence 40133 non-null object 1 label 40133 non-null int64 dtypes: int64(1), object(1) memory usage: 627.2+ KB None
我们可以看到不同类别数据分布不均,我们通过一些方法去增加少数类别的样本,使得不同类别数据量都能差不多。这里可以再补充一下,再进行数据增强时,一般都是对训练集进行数据增强,测试集仍然使用原始未增强的数据集。我们这里就假设拿到的数据全部为训练集数据。
2. 马尔科夫链数据增强
import pandas as pd from collections import Counter import jieba import logging import random from tqdm import tqdm import matplotlib.pyplot as plt import seaborn as sns jieba.setLogLevel(logging.CRITICAL) def data_plot(data): _, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=100) plt.subplots_adjust(wspace=0.5,hspace=0.6) # 类别样本数量 clases = Counter(data['label']) sns.barplot(x=list(clases.keys()), y=list(clases.values()), ax=axes[0]) axes[0].set_title('类别样本数量') axes[0].set_xlabel('class') axes[0].set_ylabel('Count') for x, y in clases.items(): axes[0].text(x, y+0.5, y, ha='center') # 句子长度分析 length = data['sentence'].map(lambda x: len(x)) sns.histplot(x=length, ax=axes[1]) axes[1].set_title('句子长度分布') plt.show() # 少数类别构建马尔科夫链 def build_markov_chain(train_data): statistics = Counter(train_data['label']) labels = statistics.keys() markov_chain = {} for label in labels: # 获得当前标签所有文本数据 data_under_label = train_data[train_data['label'] == label]['sentence'] # 每个 sentence 分词处理 cut_lines = [jieba.lcut(sentence) for sentence in data_under_label] # 构建该 label 马尔科夫链 markov_chain[label] = {} index = 1 for line in cut_lines: for word in line[index:]: key = line[index - 1] if key not in markov_chain[label]: markov_chain[label][key] = [word] else: markov_chain[label][key].append(word) index += 1 index = 1 return markov_chain # 文本增强 def markov_chain_augmentation(train_data, markov_chain): statistics = Counter(train_data['label']) max_label, max_number = statistics.most_common(1)[0] unbanlanced_labels = [label for label in statistics.keys() if label != max_label] unbanlanced_number = [max_number - statistics[label] for label in unbanlanced_labels] # unbanlanced_number = [100 for label in unbanlanced_labels] train_data = train_data.to_numpy()[:, 1:].tolist() progress = tqdm(range(sum(unbanlanced_number)), desc='增强语料') for label, number in zip(unbanlanced_labels, unbanlanced_number): chain = markov_chain[label] for _ in range(number): length = random.randint(25, 100) start = random.choice(list(chain.keys())) sentence = [start] while len(sentence) < length: start = random.choice(chain.get(start, list(chain.keys()))) sentence.append(start) train_data.append([''.join(sentence), label]) progress.update() return pd.DataFrame(train_data, columns=['sentence', 'label']) def test(): train_data = pd.read_csv('data/train.csv') markov_chain = build_markov_chain(train_data) train_data = markov_chain_augmentation(train_data, markov_chain) data_plot(train_data) if __name__ == '__main__': test()
如果样本量较大的话,上面代码执行还是比较慢,可以使用多任务方式来为每个类别增强数据。
由马尔科夫链产生的部分文本内容如下:
更好的《给力支持一直傻笑,想去把扎嘿!!!!~~就是,这是多么和谐声音,奶牛》《飞吻》霍,超过888名侦探也可以的画卷。。我好!!我喜欢看了鲜芋仙,这样 工程师;三月去看看俺们受几个演员不说别的手机等我想你嘛买泳衣,拍节目提意见!!!!!支持孩之宝,他了 气氛了澳洲牛奶香啊从淘宝买海绵宝宝是太贴心了!!!!!谢谢啊!女人的!!!!你嘛?没问题吧呵呵。~~爷爷真是集了;三月去娶个冷菜,画下了,一部拖拉机!!!!!!《害羞》是不是直接撤销或每月制定一个!!!你》《真素有文化的宝宝,想去见见小花啊这本书!!!!!!!!又老 出点汗拍的知心朋友,押韵,呐……强迫症患者就是视觉冲级强烈,而是我还没出,周杰伦被时间看她打电话的好看的组织的人都给你还那么的包包@于一体的消息;假如是最后 镯子的钟云锅的家人在我预想的不客气哦!不许我还有,超q的,你上辈的宣传,孩子的乐感,名模小丑,我也好看 嘛!!!!!!!!!!!!!!!天天看见你该有个不折不扣的,你能屹立于明加难得一位叫什么问题的食物要多爱 甜心35元,我再次感受到喉咙什么东西噢!!!《可爱的不记得留意咯就是即使你一定会开心,我的《思考》《白虎头许坤非法经营案》突然就先祝他叫帅气,真的没得说,而是对