SentencePiece 是一种用于文本处理的工具,特别适用于基于神经网络的文本生成系统。它的主要功能是将文本分割成更小的单位(称为子词单元),这些子词单元可以是完整的单词、部分单词,甚至是单个字符。
- 灵活的词汇表: SentencePiece 允许我们在训练神经网络之前预先确定词汇表的大小。这对于控制模型的复杂度非常有用。
- 处理未知词: 它可以处理训练数据中未出现过的词,通过将它们分解成子词单元来表示。
- 语言无关性: SentencePiece 不依赖于特定的语言,可以用于多种语言的文本处理。
- 端到端系统: 它可以直接从原始文本进行训练,不需要额外的预处理或后处理步骤。
SentencePiece 主要使用了两种技术:
- BPE : 通过迭代合并频繁出现的子序列来构建词汇表。
- UnigGram: 基于词频的语言模型,用于对生成的子词序列进行评分,从而选择最佳的分割方式。
pip install sentencepiece
GitHub:https://github.com/google/sentencepiece
使用示例代码:
from sentencepiece import SentencePieceTrainer from sentencepiece import SentencePieceProcessor # 1. 训练 def test01(): SentencePieceTrainer.train(input='corpus.txt', # 指定输出模型的前缀名称。模型文件包含两个文件:model_prefix.model 和 model_prefix.vocab model_prefix='model/tokenizer', # 指定输入文件的格式。可以是 'text'(默认,按行分隔的句子)或者 'tsv'(Tab 分隔的文件,第一列为句子,其他列可选) input_format='text', # 指定模型涵盖的字符的百分比 character_coverage=0.99, # 词汇表的大小,即模型最终生成的分词单位数量。包括特殊符号(如 <unk>)在内 vocab_size=163, # 指定模型类型。支持四种模型:unigram、bpe、char、word model_type='bpe', # 是否在训练前对输入的句子进行随机打乱 shuffle_input_sentence=True, # 指定 <pad> 等特殊标记 ID。设置为 -1 时,表示该符号不在词汇表中 pad_id=0, bos_id=1, eos_id=2, # 指定 <unk>(未知标记)的 ID unk_id=3, # 定义用户自定义的特殊符号。这些符号将被包含在词汇表中,且不会被进一步分词处理 user_defined_symbols=['<user>', '<system>', '<asistant>'], # 指定控制符号(如 <cls> 等)。这些符号用于控制模型的行为 control_symbols=['|CLS|', '|SEP|'], # 当模型遇到未登录词时,它将使用 |unk| 来表示这些词 unk_surface='|unk|', # 指定文本标准化规则。支持:'nmt_nfkc':标准 NFKC 正规化,用于去除不必要的符号。'identity':不进行任何标准化。 normalization_rule_name='nmt_nfkc') # 2. 加载 def test02(): # 加载方法一 tokenizer = SentencePieceProcessor() tokenizer.load('model/tokenizer.model') # 加载方法二 tokenizer = SentencePieceProcessor(model_file='model/tokenizer.model') print('词表大小:', tokenizer.vocab_size()) # 3. 编码 def test03(): tokenizer = SentencePieceProcessor(model_file='model/tokenizer.model') inputs = tokenizer.Encode(input=['郑钦文仍然创造僻'], # 指定输出的类型。可以输出 piece 的索引(int)或文本(str)。 out_type=str, # 是否在输出序列的开头添加 <bos>(句子开始标记) add_bos=True, # 是否在输出序列的末尾添加 <eos>(句子结束标记) add_eos=True, # 是否对输出的子词序列进行反转 reverse=False, # 设置为 True,遇到未登录词时则使用 unk_surface 代替,否则返回 unk_id 对应的 ID。 emit_unk_piece=True) print(inputs) # 4. 解码 def test04(): tokenizer = SentencePieceProcessor(model_file='model/tokenizer.model') outputs = tokenizer.Decode([[13, 87, 43, 56, 12]], # str 字符串类型 # bytes 字节类型 # 'serialized_proto' 一种高效的二进制格式,通常用于数据存储和网络传输。 # 'immutable_proto' 解码后得到的输出内容为不可变的协议缓冲区对象 out_type='serialized_proto') print(outputs) # print(outputs[0].text) # print(outputs[0].score) # for piece in outputs[0].pieces: # print(piece) if __name__ == '__main__': test04()