很多资料表明,通过文本数据增强也能够增强模型的分类性能。本篇文章总结几种文本数据增强的方法:
- 马尔科夫链文本增强
- 百度回译数据增强
- EDA 数据增强
1. 数据信息简单展示
接下来演示使用的数据为 json 格式,其共有 6 个类别,分别使用 0-5 的数字来表示,下面的代码对数据进行简单处理并可视化。
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
from pyhanlp import JClass
def data_plot(data):
_, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=100)
plt.subplots_adjust(wspace=0.5,hspace=0.6)
clases = Counter(data['label'])
sns.barplot(x=list(clases.keys()), y=list(clases.values()), ax=axes[0])
axes[0].set_title('类别样本数量')
axes[0].set_xlabel('class')
axes[0].set_ylabel('Count')
for x, y in clases.items():
axes[0].text(x, y+0.5, y, ha='center')
length = data['sentence'].map(lambda x: len(x))
sns.histplot(x=length, ax=axes[1])
axes[1].set_title('句子长度分布')
plt.show()
def data_split():
data = pd.read_json('data/senti-data.json')
data.columns = ['sentence', 'label']
normalizer = JClass('com.hankcs.hanlp.dictionary.other.CharTable')
def clean_data(sentence):
sentence = normalizer.convert(sentence)
sentence = ''.join(sentence.split())
return sentence
sentence = [clean_data(sentence) for sentence in data['sentence']]
data['sentence'] = sentence
data_plot(data)
print(data.info())
data.to_csv('data/train.csv')
if __name__ == '__main__':
data_split()
RangeIndex: 40133 entries, 0 to 40132 Data columns (total 2 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 sentence 40133 non-null object 1 label 40133 non-null int64 dtypes: int64(1), object(1) memory usage: 627.2+ KB None

我们可以看到不同类别数据分布不均,我们通过一些方法去增加少数类别的样本,使得不同类别数据量都能差不多。这里可以再补充一下,再进行数据增强时,一般都是对训练集进行数据增强,测试集仍然使用原始未增强的数据集。我们这里就假设拿到的数据全部为训练集数据。
2. 马尔科夫链数据增强
import pandas as pd
from collections import Counter
import jieba
import logging
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
jieba.setLogLevel(logging.CRITICAL)
def data_plot(data):
_, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=100)
plt.subplots_adjust(wspace=0.5,hspace=0.6)
# 类别样本数量
clases = Counter(data['label'])
sns.barplot(x=list(clases.keys()), y=list(clases.values()), ax=axes[0])
axes[0].set_title('类别样本数量')
axes[0].set_xlabel('class')
axes[0].set_ylabel('Count')
for x, y in clases.items():
axes[0].text(x, y+0.5, y, ha='center')
# 句子长度分析
length = data['sentence'].map(lambda x: len(x))
sns.histplot(x=length, ax=axes[1])
axes[1].set_title('句子长度分布')
plt.show()
# 少数类别构建马尔科夫链
def build_markov_chain(train_data):
statistics = Counter(train_data['label'])
labels = statistics.keys()
markov_chain = {}
for label in labels:
# 获得当前标签所有文本数据
data_under_label = train_data[train_data['label'] == label]['sentence']
# 每个 sentence 分词处理
cut_lines = [jieba.lcut(sentence) for sentence in data_under_label]
# 构建该 label 马尔科夫链
markov_chain[label] = {}
index = 1
for line in cut_lines:
for word in line[index:]:
key = line[index - 1]
if key not in markov_chain[label]:
markov_chain[label][key] = [word]
else:
markov_chain[label][key].append(word)
index += 1
index = 1
return markov_chain
# 文本增强
def markov_chain_augmentation(train_data, markov_chain):
statistics = Counter(train_data['label'])
max_label, max_number = statistics.most_common(1)[0]
unbanlanced_labels = [label for label in statistics.keys() if label != max_label]
unbanlanced_number = [max_number - statistics[label] for label in unbanlanced_labels]
# unbanlanced_number = [100 for label in unbanlanced_labels]
train_data = train_data.to_numpy()[:, 1:].tolist()
progress = tqdm(range(sum(unbanlanced_number)), desc='增强语料')
for label, number in zip(unbanlanced_labels, unbanlanced_number):
chain = markov_chain[label]
for _ in range(number):
length = random.randint(25, 100)
start = random.choice(list(chain.keys()))
sentence = [start]
while len(sentence) < length:
start = random.choice(chain.get(start, list(chain.keys())))
sentence.append(start)
train_data.append([''.join(sentence), label])
progress.update()
return pd.DataFrame(train_data, columns=['sentence', 'label'])
def test():
train_data = pd.read_csv('data/train.csv')
markov_chain = build_markov_chain(train_data)
train_data = markov_chain_augmentation(train_data, markov_chain)
data_plot(train_data)
if __name__ == '__main__':
test()
如果样本量较大的话,上面代码执行还是比较慢,可以使用多任务方式来为每个类别增强数据。

由马尔科夫链产生的部分文本内容如下:
更好的《给力支持一直傻笑,想去把扎嘿!!!!~~就是,这是多么和谐声音,奶牛》《飞吻》霍,超过888名侦探也可以的画卷。。我好!!我喜欢看了鲜芋仙,这样 工程师;三月去看看俺们受几个演员不说别的手机等我想你嘛买泳衣,拍节目提意见!!!!!支持孩之宝,他了 气氛了澳洲牛奶香啊从淘宝买海绵宝宝是太贴心了!!!!!谢谢啊!女人的!!!!你嘛?没问题吧呵呵。~~爷爷真是集了;三月去娶个冷菜,画下了,一部拖拉机!!!!!!《害羞》是不是直接撤销或每月制定一个!!!你》《真素有文化的宝宝,想去见见小花啊这本书!!!!!!!!又老 出点汗拍的知心朋友,押韵,呐……强迫症患者就是视觉冲级强烈,而是我还没出,周杰伦被时间看她打电话的好看的组织的人都给你还那么的包包@于一体的消息;假如是最后 镯子的钟云锅的家人在我预想的不客气哦!不许我还有,超q的,你上辈的宣传,孩子的乐感,名模小丑,我也好看 嘛!!!!!!!!!!!!!!!天天看见你该有个不折不扣的,你能屹立于明加难得一位叫什么问题的食物要多爱 甜心35元,我再次感受到喉咙什么东西噢!!!《可爱的不记得留意咯就是即使你一定会开心,我的《思考》《白虎头许坤非法经营案》突然就先祝他叫帅气,真的没得说,而是对

冀公网安备13050302001966号