文本数据增强 – 马尔科夫链

很多资料表明,通过文本数据增强也能够增强模型的分类性能。本篇文章总结几种文本数据增强的方法:

  1. 马尔科夫链文本增强
  2. 百度回译数据增强
  3. EDA 数据增强

1. 数据信息简单展示

接下来演示使用的数据为 json 格式,其共有 6 个类别,分别使用 0-5 的数字来表示,下面的代码对数据进行简单处理并可视化。

import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
from pyhanlp import JClass


def data_plot(data):

    _, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=100)
    plt.subplots_adjust(wspace=0.5,hspace=0.6)

    clases = Counter(data['label'])
    sns.barplot(x=list(clases.keys()), y=list(clases.values()), ax=axes[0])
    axes[0].set_title('类别样本数量')
    axes[0].set_xlabel('class')
    axes[0].set_ylabel('Count')
    for x, y in clases.items():
        axes[0].text(x, y+0.5, y, ha='center')

    length = data['sentence'].map(lambda x: len(x))
    sns.histplot(x=length, ax=axes[1])
    axes[1].set_title('句子长度分布')

    plt.show()


def data_split():

    data = pd.read_json('data/senti-data.json')
    data.columns = ['sentence', 'label']
    normalizer = JClass('com.hankcs.hanlp.dictionary.other.CharTable')
    def clean_data(sentence):
        sentence = normalizer.convert(sentence)
        sentence = ''.join(sentence.split())
        return sentence
    sentence = [clean_data(sentence) for sentence in data['sentence']]
    data['sentence'] = sentence
    data_plot(data)
    print(data.info())
    data.to_csv('data/train.csv')

if __name__ == '__main__':
    data_split()
RangeIndex: 40133 entries, 0 to 40132
Data columns (total 2 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   sentence  40133 non-null  object
 1   label     40133 non-null  int64 
dtypes: int64(1), object(1)
memory usage: 627.2+ KB
None

我们可以看到不同类别数据分布不均,我们通过一些方法去增加少数类别的样本,使得不同类别数据量都能差不多。这里可以再补充一下,再进行数据增强时,一般都是对训练集进行数据增强,测试集仍然使用原始未增强的数据集。我们这里就假设拿到的数据全部为训练集数据。

2. 马尔科夫链数据增强

import pandas as pd
from collections import Counter
import jieba
import logging
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

jieba.setLogLevel(logging.CRITICAL)


def data_plot(data):

    _, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=100)
    plt.subplots_adjust(wspace=0.5,hspace=0.6)

    # 类别样本数量
    clases = Counter(data['label'])
    sns.barplot(x=list(clases.keys()), y=list(clases.values()), ax=axes[0])
    axes[0].set_title('类别样本数量')
    axes[0].set_xlabel('class')
    axes[0].set_ylabel('Count')
    for x, y in clases.items():
        axes[0].text(x, y+0.5, y, ha='center')

    # 句子长度分析
    length = data['sentence'].map(lambda x: len(x))
    sns.histplot(x=length, ax=axes[1])
    axes[1].set_title('句子长度分布')

    plt.show()


# 少数类别构建马尔科夫链
def build_markov_chain(train_data):

    statistics = Counter(train_data['label'])
    labels = statistics.keys()
    markov_chain = {}

    for label in labels:
        # 获得当前标签所有文本数据
        data_under_label = train_data[train_data['label'] == label]['sentence']
        # 每个 sentence 分词处理
        cut_lines = [jieba.lcut(sentence) for sentence in data_under_label]
        # 构建该 label 马尔科夫链
        markov_chain[label] = {}
        index = 1
        for line in cut_lines:
            for word in line[index:]:
                key = line[index - 1]
                if key not in markov_chain[label]:
                    markov_chain[label][key] = [word]
                else:
                    markov_chain[label][key].append(word)
                index += 1
            index = 1


    return markov_chain


# 文本增强
def markov_chain_augmentation(train_data, markov_chain):

    statistics = Counter(train_data['label'])
    max_label, max_number = statistics.most_common(1)[0]
    unbanlanced_labels = [label for label in statistics.keys() if label != max_label]
    unbanlanced_number = [max_number - statistics[label] for label in unbanlanced_labels]
    # unbanlanced_number = [100 for label in unbanlanced_labels]

    train_data = train_data.to_numpy()[:, 1:].tolist()
    progress = tqdm(range(sum(unbanlanced_number)), desc='增强语料')
    for label, number in zip(unbanlanced_labels, unbanlanced_number):
        chain = markov_chain[label]
        for _ in range(number):
            length = random.randint(25, 100)
            start = random.choice(list(chain.keys()))
            sentence = [start]
            while len(sentence) < length:
                start = random.choice(chain.get(start, list(chain.keys())))
                sentence.append(start)
            train_data.append([''.join(sentence), label])
            progress.update()

    return pd.DataFrame(train_data, columns=['sentence', 'label'])


def test():

    train_data = pd.read_csv('data/train.csv')
    markov_chain = build_markov_chain(train_data)
    train_data = markov_chain_augmentation(train_data, markov_chain)
    data_plot(train_data)


if __name__ == '__main__':
    test()

如果样本量较大的话,上面代码执行还是比较慢,可以使用多任务方式来为每个类别增强数据。

由马尔科夫链产生的部分文本内容如下:

更好的《给力支持一直傻笑,想去把扎嘿!!!!~~就是,这是多么和谐声音,奶牛》《飞吻》霍,超过888名侦探也可以的画卷。。我好!!我喜欢看了鲜芋仙,这样

工程师;三月去看看俺们受几个演员不说别的手机等我想你嘛买泳衣,拍节目提意见!!!!!支持孩之宝,他了

气氛了澳洲牛奶香啊从淘宝买海绵宝宝是太贴心了!!!!!谢谢啊!女人的!!!!你嘛?没问题吧呵呵。~~爷爷真是集了;三月去娶个冷菜,画下了,一部拖拉机!!!!!!《害羞》是不是直接撤销或每月制定一个!!!你》《真素有文化的宝宝,想去见见小花啊这本书!!!!!!!!又老

出点汗拍的知心朋友,押韵,呐……强迫症患者就是视觉冲级强烈,而是我还没出,周杰伦被时间看她打电话的好看的组织的人都给你还那么的包包@于一体的消息;假如是最后

镯子的钟云锅的家人在我预想的不客气哦!不许我还有,超q的,你上辈的宣传,孩子的乐感,名模小丑,我也好看

嘛!!!!!!!!!!!!!!!天天看见你该有个不折不扣的,你能屹立于明加难得一位叫什么问题的食物要多爱

甜心35元,我再次感受到喉咙什么东西噢!!!《可爱的不记得留意咯就是即使你一定会开心,我的《思考》《白虎头许坤非法经营案》突然就先祝他叫帅气,真的没得说,而是对

未经允许不得转载:一亩三分地 » 文本数据增强 – 马尔科夫链