Bahdanau Attention And Luong Attention

Many different attention mechanisms exist for the attention-based Encoder-Decoder model. This article summarizes the Bahdanau and Luong attention computation methods based on the original papers.

  1. Bahdanau Attention
  2. Luong Attention

Reference: https://www.zhihu.com/question/68482809/answer/1742071699

1. Bahdanau Attention

Paper: Neural Machine Translation By Jointly Learning To Align And Translate

The alignment score is computed as follows (a minimal single-pair sketch is given after the list below):

score_ij = v_a^T · tanh(W_a · s_{i-1} + U_a · h_j)

score_ij denotes the score between the decoder hidden state at step i-1 and the encoder hidden state at step j.
1. s_{i-1} is the decoder hidden state from the previous step, i-1;
2. h_j is the encoder hidden state at step j;
3. W_a is the parameter of the linear transform applied to s_{i-1};
4. U_a is the parameter of the linear transform applied to h_j;
5. v_a is the parameter of the linear transform applied to the tanh output.
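
To make the formula concrete, here is a minimal sketch of the score computation for a single pair (s_{i-1}, h_j). The dimensions are illustrative assumptions, and the variable names mirror the symbols above.

import torch

# Illustrative dimensions (assumptions, not from the paper)
hidden_dim, attn_dim = 256, 64

s_prev = torch.randn(hidden_dim)         # s_{i-1}: previous decoder hidden state
h_j = torch.randn(hidden_dim)            # h_j: encoder hidden state at step j
W_a = torch.randn(attn_dim, hidden_dim)  # linear transform applied to s_{i-1}
U_a = torch.randn(attn_dim, hidden_dim)  # linear transform applied to h_j
v_a = torch.randn(attn_dim)              # projects the tanh output to a scalar

# score_ij = v_a^T · tanh(W_a · s_{i-1} + U_a · h_j)
score_ij = v_a @ torch.tanh(W_a @ s_prev + U_a @ h_j)
print(score_ij.shape)  # torch.Size([]) -- a single scalar score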

Example code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class BahdanauAttention(nn.Module):
    
    def __init__(self, encoder_hidden_dim, decoder_hidden_dim, attn_dim):
        
        super(BahdanauAttention, self).__init__()
        self.decoder_linear = nn.Linear(decoder_hidden_dim, attn_dim)
        self.encoder_linear = nn.Linear(encoder_hidden_dim, attn_dim)
        self.score_linear = nn.Linear(attn_dim, 1)


    def forward(self, value, query):

        # value (encoder output) has shape (batch_size, seq_len, encoder_hidden_dim)
        # query (decoder hidden state) has shape (batch_size, 1, decoder_hidden_dim)

        # q has shape (batch_size, 1, attn_dim)
        # k has shape (batch_size, seq_len, attn_dim)
        q = self.decoder_linear(query)
        k = self.encoder_linear(value)

        # q broadcasts over k's seq_len dimension, so the query is added to
        # every encoder step; score has shape (batch_size, seq_len, 1)
        # attn_weight has shape (batch_size, seq_len, 1)
        score = self.score_linear(torch.tanh(q + k))
        attn_weight = F.softmax(score, dim=1)

        # attn_tensor (the context vector) has shape (batch_size, encoder_hidden_dim)
        attn_tensor = torch.sum(attn_weight * value, dim=1)

        return attn_tensor, attn_weight


def test():

    # Encoder output tensor: batch size 32, seq_len 300, hidden dim 256
    encoder_output = torch.randn(32, 300, 256)
    # Decoder hidden state: batch size 32, seq_len 1, hidden dim 256
    decoder_hidden = torch.randn(32, 1, 256)

    attention = BahdanauAttention(encoder_hidden_dim=256, decoder_hidden_dim=256, attn_dim=64)
    attn_tensor, attn_weight = attention(encoder_output, decoder_hidden)
    print(attn_tensor.shape)
    print(attn_weight.shape)


if __name__ == '__main__':
    test()

Program output:

torch.Size([32, 256])
torch.Size([32, 300, 1])

2. Luong Attention

Paper: Effective Approaches To Attention-Based Neural Machine Translation

Luong Attention takes as input the encoder hidden states from all time steps together with the decoder hidden state at the current time step. The paper gives the following three ways to compute the score (a single-pair sketch follows the list below):

score(h_t, h̄_s) = h_t^T · h̄_s                     (dot)
score(h_t, h̄_s) = h_t^T · W_a · h̄_s               (general)
score(h_t, h̄_s) = v_a^T · tanh(W_a · [h_t; h̄_s])  (concat)

1. h̄_s (h with a bar) denotes the encoder hidden states at all time steps;
2. h_t is the decoder hidden state at the current time step, not the previous time step;
3. the computed scores are passed through softmax to obtain the attention weight distribution, from which the context vector c_t is derived.
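
As a minimal sketch with illustrative dimensions, the three score functions for a single pair (h_t, h̄_s) can be written as follows; the names W_a, W_c, v_a mirror the notation above (W_c is the concat variant's W_a).

import torch

# Illustrative dimensions (assumptions, not from the paper)
hidden_dim = 256

h_t = torch.randn(hidden_dim)                  # current decoder hidden state h_t
h_s = torch.randn(hidden_dim)                  # one encoder hidden state h̄_s
W_a = torch.randn(hidden_dim, hidden_dim)      # parameter of the general score
W_c = torch.randn(hidden_dim, 2 * hidden_dim)  # W_a of the concat score
v_a = torch.randn(hidden_dim)                  # projects the tanh output to a scalar

score_dot = h_t @ h_s                                          # dot
score_general = h_t @ (W_a @ h_s)                              # general
score_concat = v_a @ torch.tanh(W_c @ torch.cat([h_t, h_s]))   # concat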

Example code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class LuongAttention(nn.Module):
    
    def __init__(self, encoder_hidden_dim, decoder_hidden_dim, attn_dim):
        
        super(LuongAttention, self).__init__()
        # W_a of the general score
        self.encoder_linear = nn.Linear(encoder_hidden_dim, encoder_hidden_dim)
        # parameters of the concat score: linear1 plays the role of W_a acting
        # on [h_t; h̄_s], and linear2 plays the role of v_a, projecting each
        # encoder step down to a scalar score
        self.linear1 = nn.Linear(decoder_hidden_dim + encoder_hidden_dim, attn_dim)
        self.linear2 = nn.Linear(attn_dim, 1)


    def forward(self, value, query):

        # value (encoder output) has shape (batch_size, seq_len, encoder_hidden_dim)
        # query (decoder hidden state) has shape (batch_size, 1, decoder_hidden_dim)

        # dot: score has shape (batch_size, seq_len, 1)
        score = value @ query.transpose(1, 2)
        print(score.shape)

        # general: score has shape (batch_size, seq_len, 1)
        score = self.encoder_linear(value) @ query.transpose(1, 2)
        print(score.shape)

        # concat: concatenate the query with every encoder step, then project
        # down to a scalar per step; score has shape (batch_size, seq_len, 1)
        concat = torch.cat([query.expand(-1, value.size(1), -1), value], dim=-1)
        score = self.linear2(torch.tanh(self.linear1(concat)))
        print(score.shape)

        attn_weight = F.softmax(score, dim=1)
        attn_tensor = torch.sum(attn_weight * value, dim=1)

        return attn_tensor, attn_weight


def test():

    # Encoder output tensor: batch size 32, seq_len 300, hidden dim 256
    encoder_output = torch.randn(32, 300, 256)
    # Current decoder hidden state: batch size 32, seq_len 1, hidden dim 256
    decoder_hidden = torch.randn(32, 1, 256)

    attention = LuongAttention(encoder_hidden_dim=256, decoder_hidden_dim=256, attn_dim=64)
    attn_tensor, attn_weight = attention(encoder_output, decoder_hidden)
    print(attn_tensor.shape)
    print(attn_weight.shape)


if __name__ == '__main__':
    test()

Program output:

torch.Size([32, 300, 1])
torch.Size([32, 300, 1])
torch.Size([32, 300, 1])
torch.Size([32, 256])
torch.Size([32, 300, 1])
