Bert MHA 源码分析

我们在使用 Bert 模型时,对每一个 token 的表征计算都是通过其内部的自注意力机制来完成的,具体就是由 Bert 模型的 BertAttention 来负责自注意力计算,分析的实现代码是 transformers==4.22.2,下面是该类的实现代码:

class BertAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # 多头自注意力计算层
        self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
        # 计算最终输出
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

上面源码的 forward 函数,可以看到计算就分为两步,分别是多头自注意力计算,输出计算。由 init 函数中可以看到,该组件是由 BertSelfAttention 和 BertSelfOutput 两部分组成,我们接下来分析这两部分对输入究竟做了哪些计算过程。

1. BertSelfAttention

BertSelfAttention 是 Bert 多头自注意力计算层,该类主要包含三个函数:

  1. init 函数,初始化各种需要的组件
  2. transpose_for_scores 函数,用于在自注意力计算过程中的转置操作
  3. forward 函数,就是自注意力计算的过程

我们先看 init 函数,看看该类主要初始化了哪些组件,源代码以及注释如下:

def __init__(self, config, position_embedding_type=None):
    super().__init__()
    if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
        raise ValueError(
            f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
            f"heads ({config.num_attention_heads})"
        )

    # 从 BertConfig 对象读取配置的注意力头数量, 这里默认头数是: 12
    self.num_attention_heads = config.num_attention_heads
    
    # 计算每一个注意力头输出的向量维度
    # 我们输入的维度是 config.hidden_size=768, 除以 config.num_attention_heads=12 之后, 得到 64
    # 即: 每一个头输出的维度是 64 维,将 12 个头的维度拼接起来,会得到 768 的向量
    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
    
    # 由于我们这里使用 batch 的概念来批量计算 12 个注意力头,这里就计算下 12 个头输出的向量维度是多少
    # 经过计算,self.all_head_size=768
    self.all_head_size = self.num_attention_heads * self.attention_head_size

    # 自注意力计算过程中,QKV 都是通过线性变换得到,这里初始化一个 (768, 12, 64) 的线性层
    # 由于线性层只有输入维度和输出维度,所以第二维度 self.all_head_size=12*64=768
    # 我们可以理解这个 768 行 768 列的矩阵中,768列实际上分成了 12 个区域,每个区域代表一个头的所有参数
    # 其他的 key 和 value 层与 query 同理,所以他们的形状都是一样的
    self.query = nn.Linear(config.hidden_size, self.all_head_size)
    self.key = nn.Linear(config.hidden_size, self.all_head_size)
    self.value = nn.Linear(config.hidden_size, self.all_head_size)

    # 下面部分可以先不用关注
    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
    self.position_embedding_type = position_embedding_type or getattr(
        config, "position_embedding_type", "absolute"
    )
    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
        self.max_position_embeddings = config.max_position_embeddings
        self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    self.is_decoder = config.is_decoder

接下来,我们看下 transpose_for_scores 函数,该函数主要是对接下来进行的多头自注意力计算进行转置操作,便于实现矩阵运算,下面是其源代码实现以及注释:

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
	# 从 x.size()[:-1] 可以看到,我们将输入数据的最后一维丢掉,并增加了两个新的维度
	# 例如: 我们输入的数据维度是 (1, 3, 768),经过此函数的转置操作之后会变成 (1, 3, 12, 64)
	# 上面数据形状中的 1 表示 batch size 大小
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(new_x_shape)
    # 这里再将第二维、第三维交换位置,由 (1, 3, 12, 64) -> (1, 12, 3, 64)
    # 把头这个维度提前,是为了方便后续分别计算每一个头的向量表示
    return x.permute(0, 2, 1, 3)

接下来,我们再分析下最重要的 forward 计算过程,由于源代码内容较多,我们只考虑当输入一个序列数据时,forward 经过了下面的 9 步骤计算。

我们假设输入的数据为:

from transformers import BertModel
from transformers import BertConfig


def test():
    BertModel(config=BertConfig())(input_ids=torch.tensor([[1, 2, 3]]))

if __name__ == '__main__':
    test()

具体的计算步骤下面已经标记出来:

    def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
#############################第1步开始##############################
# 这一步对输入的隐藏状态向量,按照我们前面假设输入的数据,这里 hidden_states 的形状为 (1, 3, 768)
# 我们前面提过 self.query 层包含了 12 个头的参数,所以经过变换之后,就得到输入 hidden_states 的 12 个 query 向量
# self.query 输入和输出维度都是 768, 所以经过变换后依然输出的维度是 mixed_query_layer=(1, 3, 768)
# 虽然形状和 hidden_states 一样,但是要知道数据表示的含义可就不同了,输出数据表示 12 个头的 query 向量表示
mixed_query_layer = self.query(hidden_states)
#############################第1步结束##############################
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
#############################第2步开始##############################
# 前面计算过了 query 向量,这里计算的是 12 个头的 key 和 value 向量表示
# 但是这里额外做了转置操作,目的是为了能够 query 和 key 进行矩阵运算
# key、value 经过转置之后,数据由 (1, 3, 768) 变成了 (1, 12, 3, 64)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
#############################第2步结束##############################
#############################第3步开始##############################
# 这个将 query 向量也进行转置,从 (1, 3, 768) 变成 (1, 12, 3, 64)
# 此时, query、key、value 都变成了 (1, 12, 3, 64)
# 这个数据可以理解为: 每个头的 qkv 都是一个 (3, 64) 的向量,3 表示 token 的数量
query_layer = self.transpose_for_scores(mixed_query_layer)
#############################第3步结束##############################
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
#############################第4步开始##############################
# 这里开始计算注意力分数,query_layer=(1, 12, 3, 64) 
# key_layer.transpose(-1, -2) 转置了最后两个维度之后由 (1, 12, 3, 64) 变为 (1, 12, 64, 3)
# 最终计算得到的 attention_scores 的形状为: (1, 12, 3, 64) @ (1, 12, 64, 3) = (1, 12, 3, 3)
# 这个 (3, 3) 每一行表示某个 token 对其他 token 的注意力分数
# 3 行表示 3 个 token 分别对其他 token 的注意力分数
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
#############################第4步结束##############################
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
#############################第5步开始##############################
# 这里对注意力分数进行了缩放 math.sqrt(self.attention_head_size) = math.sqrt(64) = 8
# 为什么要这么做?
# 因为后面要将分数变成概率表示,如果分数之间差值很大,就会导致计算概率时有些值变成了 0,使得 token 无法注意力到该 token
# 所以这里对分数进行缩放之后,再将其转换为注意力概率表示
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
#############################第5步结束##############################
#############################第6步开始##############################
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
# 猛一看这里对分数加上了 attention_mask,其实该 attention_mask 已经经过了处理
# 原来我们输入的 mask 是 [1, 1, 1], 这里 attention_mask 已经变成了 [0, 0, 0], 相当于没什么变化
# 但是假设我们的 mask 是 [1, 1, 0] 的话,也就是说有一个位置不需要计算,此时 attention_mask 会被处理成 [0, 0, -10000]
# 此时计算得到的 attention_scores 不需要计算的位置就是很小的负数,计算注意力概率值时,该位置相当于 0,相当于掩码了
attention_scores = attention_scores + attention_mask
#############################第6步结束##############################
# Normalize the attention scores to probabilities.
#############################第7步开始##############################
# 经过全面对分数的缩放,以及掩码的操作,这里可以正常计算注意力权重分布了
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
#############################第7步结束##############################
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
#############################第8步开始##############################
# 这一部分就是随机掩码掉一些 token 的注意力权重,减少一些计算量,可能还能增加模型的性能
# 我的理解是: 输入 200 个 token,我们可能没必须对所有 token 计算注意力,所以随机丢弃一小部分
attention_probs = self.dropout(attention_probs)
#############################第8步结束##############################
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
#############################第9步开始##############################
# 这一步就比较好理解了,注意力权重 (1, 12, 3, 3)@(1, 12, 3, 64)=(1, 12, 3, 64)
# 分别得到了每一个头的向量表示
context_layer = torch.matmul(attention_probs, value_layer)
# 这一步将形状由 (1, 12, 3, 64) 变为 (1, 3, 12, 64)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
# 下面两步再将形状由 (1, 3, 12, 64) 变为 (1, 3, 768), 此时得到了多头自注意力的计算结果向量
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
#############################第9步结束##############################
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs

另外,需要注意的是,假设 12 个头,3 个 token 的话,那么:

  1. 12 个 head,表示每个 token 都有 12 组 query、key、value
  2. 第一个 token 的 12 个 head 中的第 1 个 query 会关注自身以及另外 2 个 token 的 key,计算得到注意力分数,并转换为概率表示,然后乘以 3 个 token 的 value 向量,最后将这 3 个向量直接相加,得到第 1 token 第 1 个 head 的向量表示
  3. 将 12 个 head 的注意力向量都计算出来,拼接起来就得到了第 1 个 token 的注意力表征向量

2. BertSelfOutput

class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
# 线性变换层
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
# 层归一化
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# 随机丢弃层
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
# 这里做了一个残差连接,将线性变换之后的结果+上原始输入,再进行层归一化计算
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states

关于归一化的操作,可以看下面的链接

未经允许不得转载:一亩三分地 » Bert MHA 源码分析
评论 (0)

9 + 4 =