Transformers 解码方法

在使用自回归模型做生成任务时,例如使用 GPT2 做生成任务,需要解码预测结果。会接触到以下几种解码方法:

  1. Greedy Search
  2. Beam Search
  3. Top-K Sampling
  4. Nucleus Sampling

接下来,就总结下这几种解码方法在 transformers 库中的使用,最主要用的是 generate 函数。
使用到的模型:临时中转 – NovelGeneratorGPT2 – checkpoint_-4.8810.pt

import torch
from transformers import BertTokenizer
from transformers import GPT2LMHeadModel
from transformers import GPT2Config

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = 'temp/checkpoint_-4.8810.pt'
objects = torch.load(model_path, map_location=device)
tokenizer = BertTokenizer.from_pretrained('model')
config = GPT2Config(vocab_size=tokenizer.vocab_size)
model = GPT2LMHeadModel(config=config)
model.load_state_dict(objects['model'])
model.eval()

start_text = '杨妃'
input_ids = tokenizer.encode(start_text, add_special_tokens=False, return_tensors='pt')
print(input_ids)

1. Greedy Search

在解码时,贪心搜索每次都选择当前时间步概率最高的词作为输出。它得出的序列一般不是概率最高。

def test01():

    # input_ids 表示起始文本
    # max_length 生成最大长度
    # early_stopping 是否碰到结束符提前停止, 需要设置 eos_token_id
    outputs = model.generate(input_ids=input_ids, max_length=100, early_stopping=False)
    # 将数字 token 还原为文本内容
    outputs = tokenizer.decode(outputs[0])
    # 去除内容中的空格
    outputs = ''.join(outputs.split())
    # 截取到最后一个句号
    end_pos = outputs.rfind('。') + 1
    # 输出内容
    print(outputs[:end_pos])

程序输出结果:

杨妃道:“善信正是太子侧妃郡王李恪,为师,小儿也正是汉中郡王李恪,不知乾道如何得知善信的身份?”道童听闻杨妃自承身份,于是回道:“家师昨日收到娘娘拜帖,便知娘娘与殿下今日来此,特命小道在此等候。

2. Beam Search

如果要找到一个全局最优的,也就是句子的似然概率最大的序列,可以穷举出所有可能的序列,并计算每个序列的似然概率,取概率最大的作为输出。但是,这样的解码过程效率太低了。如果产生一个 200 长度的序列,每个字/词有 1 万个候选词,穷举过程将会是灾难性的。Beam Search 则是在穷举和贪心搜索之间的一种这种的方法。

def test02():

    # num_beams 为设置的每次候选词数量
    outputs = model.generate(input_ids,
                             max_length=100,
                             num_beams=3,
                             early_stopping=False)
    outputs = tokenizer.decode(outputs[0])
    outputs = ''.join(outputs.split())
    end_pos = outputs.rfind('。') + 1
    print(outputs[:end_pos])

程序输出结果:

杨妃道:“娘娘要婢子给阿郎带句话:‘长孙家能有今日,靠的不是落雕弓,而长孙家的人,与太子的储位和长孙家未来富贵相比,一把落雕弓又算得了什么。

3. Top-K Sampling

Beam Search 产生的句子缺失了文本生成中的多样性。我们可以在每一步产生 k 个候选词,一般为概率最高的 k 个词作为候选,然后根据其概率值,随机从 k 个中选择一个作为当前时间步的输出。

def test03():
    # num_beams 为设置的每次候选词数量
    outputs = model.generate(input_ids, max_length=100, do_sample=True, top_k=10)
    outputs = tokenizer.decode(outputs[0])
    outputs = ''.join(outputs.split())
    end_pos = outputs.rfind('。') + 1
    print(outputs[:end_pos])

程序输出结果(每次生成结果):

杨妃道:“小妹放心,想必是太子给你一个连位置吧,陛下为‘恪’,你可敦须知娘和富人,只是外人,我举尽可迎娶妻,可汗的。”她李世民听了李恪的话,他的话,倒也未直接着杨妃道,眉头。

4. Nucleus Sampling

Top-K 每次将会选择固定个数的候选词,有时候候选词的概率分布较为平均的话,固定个数的候选词将会导致后面一些概率差不多词没有机会被选上。Nucleus Sampling 则是根据概率分布来确定候选词的数量,比如我们可以设置一个阈值概率综合 0.95, 从候选词中选择出一组候选词,这些词的概率总和大于等于 0.95,然后从这些词中候选。由于每个时间步词的概率分布不同,则候选词的数量也是不同的。gen

def test04():

    outputs = model.generate(input_ids, max_length=100, do_sample=True, top_p=0.95)
    outputs = tokenizer.decode(outputs[0])
    outputs = ''.join(outputs.split())
    end_pos = outputs.rfind('。') + 1
    print(outputs[:end_pos])

程序输出结果(每次生成结果不同):

杨妃道:“善信正是太子小儿也正是特命小人,小儿也正是汉中郡王李恪,不知乾道如何得知善信的身份?”道童听闻杨妃自承身份,于是回道:“家师昨日收到娘娘拜帖,便知娘娘与殿下今日来此,特命小道在此等候。

未经允许不得转载:一亩三分地 » Transformers 解码方法