文本数据增强 – 回译

很多资料表明,通过文本数据增强也能够增强模型的分类性能。本篇文章总结几种文本数据增强的方法:

  1. 马尔科夫链文本增强
  2. 百度回译数据增强
  3. EDA 数据增强

1. 数据信息简单展示

接下来演示使用的数据为 json 格式,其共有 6 个类别,分别使用 0-5 的数字来表示,下面的代码对数据进行简单处理并可视化。

import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
from pyhanlp import JClass


def data_plot(data):

    _, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=100)
    plt.subplots_adjust(wspace=0.5,hspace=0.6)

    clases = Counter(data['label'])
    sns.barplot(x=list(clases.keys()), y=list(clases.values()), ax=axes[0])
    axes[0].set_title('类别样本数量')
    axes[0].set_xlabel('class')
    axes[0].set_ylabel('Count')
    for x, y in clases.items():
        axes[0].text(x, y+0.5, y, ha='center')

    length = data['sentence'].map(lambda x: len(x))
    sns.histplot(x=length, ax=axes[1])
    axes[1].set_title('句子长度分布')

    plt.show()


def data_split():

    data = pd.read_json('data/senti-data.json')
    data.columns = ['sentence', 'label']
    normalizer = JClass('com.hankcs.hanlp.dictionary.other.CharTable')
    def clean_data(sentence):
        sentence = normalizer.convert(sentence)
        sentence = ''.join(sentence.split())
        return sentence
    sentence = [clean_data(sentence) for sentence in data['sentence']]
    data['sentence'] = sentence
    data_plot(data)
    print(data.info())
    data.to_csv('data/train.csv')

if __name__ == '__main__':
    data_split()

我们可以看到不同类别数据分布不均,我们通过一些方法去增加少数类别的样本,使得不同类别数据量都能差不多。这里可以再补充一下,再进行数据增强时,一般都是对训练集进行数据增强,测试集仍然使用原始未增强的数据集。我们这里就假设拿到的数据全部为训练集数据。

2. 回译数据增强方法

使用 google_trans_new 工具进行翻译时,总是出现超时,特别是短时内请求次数增多时,后面基本无法获得翻译结果。我这里使用百度的翻译接口,但是是有限制的,所以并未演示所有的数据的增强。对每个类别增强 10 条样本。如果不坐限制的话,默认是把类别样本增加自身大小的一倍,当然其他的样本也可以继续通过翻译的方式补充,也可以通过其他的数据增强方法来补充。

百度翻译接口的几个付费级别:

示例代码:

import pandas as pd
from collections import Counter
import logging
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import hashlib
import requests
import json
import time


def data_plot(data):

    _, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=100)
    plt.subplots_adjust(wspace=0.5,hspace=0.6)

    # 类别样本数量
    clases = Counter(data['label'])
    sns.barplot(x=list(clases.keys()), y=list(clases.values()), ax=axes[0])
    axes[0].set_title('类别样本数量')
    axes[0].set_xlabel('class')
    axes[0].set_ylabel('Count')
    for x, y in clases.items():
        axes[0].text(x, y+0.5, y, ha='center')

    # 句子长度分析
    length = data['sentence'].map(lambda x: len(x))
    sns.histplot(x=length, ax=axes[1])
    axes[1].set_title('句子长度分布')

    plt.show()


def translate(text, lang_src, lang_dst):
    if text == '':
        return ''
    # 开发者信息
    appid = ''  # 写自己的
    secret_key = ''  # 写自己

    query, lang_src, lang_dst = str(text), lang_src, lang_dst
    salt = str(random.randint(100, 10000))
    sign = (appid + query + salt + secret_key).encode('utf-8')
    sign = hashlib.md5(sign).hexdigest()
    request_url = 'http://api.fanyi.baidu.com/api/trans/vip/translate?'
    request_url = request_url + 'q=' + query + '&from=' + lang_src + '&to=' + lang_dst + '&appid=' + appid + '&salt=' + salt + '&sign=' + sign
    response = requests.get(request_url)
    response_json = response.content.decode('utf-8')
    response_dict = json.loads(response_json)
    if response_dict.get('error_code'):
        print('错误码:', response_dict['error_code'], '错误描述:', response_dict['error_msg'])
        return ''
    return response_dict['trans_result'][0]['dst']


def back_translation_augmentation(train_data):

    statistics = Counter(train_data['label'])
    max_label, max_number = statistics.most_common(1)[0]
    unbanlanced_labels = [label for label in set(statistics.keys()) if label != max_label]
    new_train_data = train_data.to_numpy()[:, 1:].tolist()
    for label in unbanlanced_labels:
        # 设置每个类别只翻译回译 10 条样本
        sentences = train_data[train_data['label'] == label]['sentence'][:10]
        for sentence in sentences:
            print(sentence)
            sentence = translate(sentence, 'zh', 'kor')
            sentence = translate(sentence, 'kor', 'en')
            sentence = translate(sentence, 'en', 'zh')
            print(sentence)
            print('-' * 100)
            new_train_data.append([sentence, label])

    return pd.DataFrame(new_train_data, columns=['sentence', 'label'])

def test():

    train_data = pd.read_csv('data/train.csv')
    train_data = back_translation_augmentation(train_data)
    data_plot(train_data)


if __name__ == '__main__':
    test()

通过回译生成的新的语料部分如下:

银色的罗马高跟鞋,圆球吊饰耳饰单带,个性十足,有非常抢眼!
银色罗马高跟鞋、球饰、耳环和腰带,个性十足,引人注目!
-------------------------------------------------------------------------------------
《爱你》祝福凯爷婷姐新婚快乐《心》,1016我们天津见!《蛋糕》
我爱你,祝福春天茉莉花郑姐姐的婚礼。心声,10月16日在天津见!糕饼
-------------------------------------------------------------------------------------
等考完试我会继续爱翰!
考试之后,我会永远爱你!
-------------------------------------------------------------------------------------
不过去年扭了脖子带来了好运气,希望这次也是!
去年扭伤的脖子带来了好运。我希望这次也一样!
-------------------------------------------------------------------------------------
我希望无论是我们80后,还是70后,都可以一直保持着一份童心,一份简单。
我希望我们能继续保持孩子般的天真和单纯,无论是80后还是70后。
-------------------------------------------------------------------------------------
《zongzi》《zongzi》额想吃粽子妈妈包的粽子特别想吃《哭泣女》
“粽子”,“粽子”想吃妈妈给我做的粽子,“哭哭啼啼的女孩”想吃得太多了。
-------------------------------------------------------------------------------------
翻译了一下,也希望能说给云南的同胞听,大家加油!
我希望我能把它送给我在云南的同胞。大家加油!
-------------------------------------------------------------------------------------
回复@dr小刀:《给力》
@剑博士:<力量>

未经允许不得转载:一亩三分地 » 文本数据增强 – 回译