PyTorch 构建词表

在解决 NLP 任务之前, 首先就要构建自己的词表。词表的作用就是给定语料,将文本中的以字为单位、或者以词为单位转换为整数序号,该序号可用于在词嵌入的 lookup table 中搜索词向量。

接下来,我们介绍下词表的构建过程,当然每个人的词表构建过程中对语料的某些处理细节不同,但是大步骤基本都是相同的。我们将使用 LCSTS 数据集构建词表。

LCSTS 数据集是哈工大基于新闻媒体在微博上发布的新闻构建的,内容包含:

  1. 总样本有 210 万样本
  2. 一篇短文(约100个字符)对应一篇摘要(约20个字符)

共有三个文件,如下所示:

文件内的每条样本如下:

我们就编写 Vocab 类,对上图中的 short_text、summary 内容构建词表。

import glob
import re
import jieba
import time
from collections import Counter
from multiprocessing import Pool
from multiprocessing import cpu_count
import random
from os.path import exists


def timer(func):

    def inner(*args, **kwargs):

        print(args[0].flag)
        print('function [%s] starts runing' % (func.__name__))
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print('function [%s] executed for %.2f seconds' % (func.__name__, end - start))
        print('-' * 51)
        return result
    return inner


# 分词函数
def cut_word(data):
    """
    :param data: 要分词的多行文本
    :return: 分词之后的结果
    """

    # 加载停用词
    stop_words = []
    for line in open('data/stopword.txt'):
        stop_words.append(line)

    content_word = []
    for line in data:
        words = jieba.lcut(line)
        content_word.append(words)
        # 停用词过滤
        # temp = []
        # for word in words:
        #     if word not in stop_words:
        #         temp.append(word)
        # content_word.append(temp)
    return content_word


# 统计词频函数
def statistics_words(data):

    word_freq = Counter()
    for words in data:
        word_freq += Counter(words)
    return word_freq


# 并行结巴分词
def parallel(task, data):

    # 进程池开启
    pool = Pool()
    cut_word_results = pool.map(task, data)
    pool.close()
    pool.join()

    return cut_word_results


def split_data(data, part_num):
    """
    :param data: 需要拆分的数据
    :param part_num: 将数据拆分成几块
    :return: 返回块列表
    """

    data_number = len(data)
    step, rest = divmod(data_number, part_num)
    blocks = [data[start * step:  start * step + step + 1] for start in range(part_num)]
    # 如果无法按块的数量整除,则剩余部分数据追加到最后一个块中
    if rest > 0:
        blocks[-1].extend(data[-rest:])

    return blocks



# 设置结巴分词不显示日志
jieba.setLogLevel(20)


class Vocab:

    # 原始文本分词之后的内容
    SHORT_SAVE_PATH = 'vocab/short_%s.txt'
    # 基于原始文本构建的词表
    VOCAB_SAVE_PATH = 'vocab/vocab_%s.txt'

    def __init__(self, flag='encoder'):
        """
        :param flag: encoder 加载原始文本, decoder 加载摘要文本
        """

        # 处理原文还是摘要
        self.flag = flag
        # 文本内容路径
        self.short_path = Vocab.SHORT_SAVE_PATH % self.flag
        self.vocab_path = Vocab.VOCAB_SAVE_PATH % self.flag
        # 获得数据集文件列表
        self.filenmes = glob.glob('data/DATA/*.txt')

        if exists(self.short_path) and exists(self.vocab_path):
            self.load()
        else:
            self.build()

    def load(self):

        # 存储词到索引的映射
        self.word_to_index = {}
        # 存储索引到词的映射
        self.index_to_word = {}
        # 存储词在语料中的频数
        self.freq_of_word = {}

        # 加载词表
        for line in open(self.vocab_path, 'r'):
            word, idx, freq = line.split()
            idx, freq = int(idx), int(freq)
            self.word_to_index[word] = idx
            self.index_to_word[idx] = word
            self.freq_of_word[word] = freq

        # 加载数据
        self.content_words = []
        for line in open(self.short_path, 'r'):
            self.content_words.append(line.split())

        # 开始记录索引
        self.token_index = 0

    def build(self):

        # 存储词到索引的映射
        self.word_to_index = {'PAD': 0, 'UNK': 1, 'SOS': 2, 'EOS': 3}
        # 存储索引到词的映射
        self.index_to_word = {value: key for key, value in self.word_to_index.items()}
        # 存储词在语料中的频数
        self.freq_of_word = Counter()
        # 开始记录索引
        self.token_index = len(self.word_to_index)
        # 语料的分词结果
        self.content_words = []

        # 1. 加载文件内容
        contents = self._load_from_txt()
        # 2. 分词处理
        self._cut_word(contents)
        # 3. 构建词表
        self._build_vocab()
        # 4. 存储词表
        self._save_dict()

    @timer
    def _load_from_txt(self):
        """读取文件内容"""

        contents = []
        for filename in self.filenmes:
            content = open(filename, 'r').read()
            # 提取短文内容
            if self.flag == 'encoder':
                match_content = re.findall(r'<short_text>(.*?)</short_text>', content, re.S)
            if self.flag == 'decoder':
                match_content = re.findall(r'<summary>(.*?)</summary>', content, re.S)
            for content in match_content:
                contents.append(content.strip())

        return contents

    @timer
    def _cut_word(self, contents):
        """内容分词"""

        # 1. 拆分数据
        blocks = split_data(contents, 12)

        # 2. 并行分词
        cut_word_result = parallel(cut_word, blocks)

        # 合并多进程分词结果
        for result in cut_word_result:
            self.content_words.extend(result)

    @timer
    def _build_vocab(self):
        """构建词典"""

        # 1. 词频数据分块
        blocks = split_data(self.content_words, 12)

        # 2. 并行统计词频
        statistics_results = parallel(statistics_words, blocks)

        # 3. 合并统计结果
        for result in statistics_results:
            self.freq_of_word += result

        # 4. 构建词索引映射
        for word, _ in self.freq_of_word.items():
            if word not in self.word_to_index:
                self.word_to_index[word] = self.token_index
                self.index_to_word[self.token_index] = word
                self.token_index += 1

    @timer
    def _save_dict(self):
        """存储词表"""

        with open(self.vocab_path, 'w') as file:
            for word, idx in self.word_to_index.items():
                if word.strip() == '':
                    continue
                file.write("%s %s %s\n" % (word, idx, self.freq_of_word[word]))

        with open(self.short_path, 'w') as file:
            for content in self.content_words:
                file.write(' '.join(content) + '\n')


if __name__ == '__main__':

    encoder_vocab = Vocab('encoder')
    decoder_vocab = Vocab('decoder')
    print('encoder_vocab:', len(encoder_vocab.word_to_index))
    print('decoder_vocab:', len(decoder_vocab.word_to_index))

在对语料库构建词表过程中,一开始发现整个执行过程极其的慢,通过分析程序,对分词、统计词频的实现增加进程池并发,提升了至少 8 倍的效率。

未经允许不得转载:一亩三分地 » PyTorch 构建词表