Bert MHA 源码分析

我们在使用 Bert 模型时,对每一个 token 的表征计算都是通过其内部的自注意力机制来完成的,具体就是由 Bert 模型的 BertAttention 来负责自注意力计算,分析的实现代码是 transformers==4.22.2,下面是该类的实现代码:

class BertAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # 多头自注意力计算层
        self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
        # 计算最终输出
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

上面源码的 forward 函数,可以看到计算就分为两步,分别是多头自注意力计算,输出计算。由 init 函数中可以看到,该组件是由 BertSelfAttention 和 BertSelfOutput 两部分组成,我们接下来分析这两部分对输入究竟做了哪些计算过程。

1. BertSelfAttention

BertSelfAttention 是 Bert 多头自注意力计算层,该类主要包含三个函数:

  1. init 函数,初始化各种需要的组件
  2. transpose_for_scores 函数,用于在自注意力计算过程中的转置操作
  3. forward 函数,就是自注意力计算的过程

我们先看 init 函数,看看该类主要初始化了哪些组件,源代码以及注释如下:

def __init__(self, config, position_embedding_type=None):
    super().__init__()
    if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
        raise ValueError(
            f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
            f"heads ({config.num_attention_heads})"
        )

    # 从 BertConfig 对象读取配置的注意力头数量, 这里默认头数是: 12
    self.num_attention_heads = config.num_attention_heads
    
    # 计算每一个注意力头输出的向量维度
    # 我们输入的维度是 config.hidden_size=768, 除以 config.num_attention_heads=12 之后, 得到 64
    # 即: 每一个头输出的维度是 64 维,将 12 个头的维度拼接起来,会得到 768 的向量
    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
    
    # 由于我们这里使用 batch 的概念来批量计算 12 个注意力头,这里就计算下 12 个头输出的向量维度是多少
    # 经过计算,self.all_head_size=768
    self.all_head_size = self.num_attention_heads * self.attention_head_size

    # 自注意力计算过程中,QKV 都是通过线性变换得到,这里初始化一个 (768, 12, 64) 的线性层
    # 由于线性层只有输入维度和输出维度,所以第二维度 self.all_head_size=12*64=768
    # 我们可以理解这个 768 行 768 列的矩阵中,768列实际上分成了 12 个区域,每个区域代表一个头的所有参数
    # 其他的 key 和 value 层与 query 同理,所以他们的形状都是一样的
    self.query = nn.Linear(config.hidden_size, self.all_head_size)
    self.key = nn.Linear(config.hidden_size, self.all_head_size)
    self.value = nn.Linear(config.hidden_size, self.all_head_size)

    # 下面部分可以先不用关注
    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
    self.position_embedding_type = position_embedding_type or getattr(
        config, "position_embedding_type", "absolute"
    )
    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
        self.max_position_embeddings = config.max_position_embeddings
        self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    self.is_decoder = config.is_decoder

接下来,我们看下 transpose_for_scores 函数,该函数主要是对接下来进行的多头自注意力计算进行转置操作,便于实现矩阵运算,下面是其源代码实现以及注释:

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
	# 从 x.size()[:-1] 可以看到,我们将输入数据的最后一维丢掉,并增加了两个新的维度
	# 例如: 我们输入的数据维度是 (1, 3, 768),经过此函数的转置操作之后会变成 (1, 3, 12, 64)
	# 上面数据形状中的 1 表示 batch size 大小
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(new_x_shape)
    # 这里再将第二维、第三维交换位置,由 (1, 3, 12, 64) -> (1, 12, 3, 64)
    # 把头这个维度提前,是为了方便后续分别计算每一个头的向量表示
    return x.permute(0, 2, 1, 3)

接下来,我们再分析下最重要的 forward 计算过程,由于源代码内容较多,我们只考虑当输入一个序列数据时,forward 经过了下面的 9 步骤计算。

我们假设输入的数据为:

from transformers import BertModel
from transformers import BertConfig


def test():
    BertModel(config=BertConfig())(input_ids=torch.tensor([[1, 2, 3]]))

if __name__ == '__main__':
    test()

具体的计算步骤下面已经标记出来:

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:

    	#############################第1步开始##############################
    	# 这一步对输入的隐藏状态向量,按照我们前面假设输入的数据,这里 hidden_states 的形状为 (1, 3, 768)
    	# 我们前面提过 self.query 层包含了 12 个头的参数,所以经过变换之后,就得到输入 hidden_states 的 12 个 query 向量
    	# self.query 输入和输出维度都是 768, 所以经过变换后依然输出的维度是 mixed_query_layer=(1, 3, 768)
    	# 虽然形状和 hidden_states 一样,但是要知道数据表示的含义可就不同了,输出数据表示 12 个头的 query 向量表示
        mixed_query_layer = self.query(hidden_states)
        #############################第1步结束##############################

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
        	#############################第2步开始##############################
        	# 前面计算过了 query 向量,这里计算的是 12 个头的 key 和 value 向量表示
        	# 但是这里额外做了转置操作,目的是为了能够 query 和 key 进行矩阵运算
        	# key、value 经过转置之后,数据由 (1, 3, 768) 变成了 (1, 12, 3, 64)
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            #############################第2步结束##############################


        #############################第3步开始##############################
        # 这个将 query 向量也进行转置,从 (1, 3, 768) 变成 (1, 12, 3, 64)
        # 此时, query、key、value 都变成了 (1, 12, 3, 64)
        # 这个数据可以理解为: 每个头的 qkv 都是一个 (3, 64) 的向量,3 表示 token 的数量
        query_layer = self.transpose_for_scores(mixed_query_layer)
        #############################第3步结束##############################

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)


        # Take the dot product between "query" and "key" to get the raw attention scores.
        #############################第4步开始##############################
        # 这里开始计算注意力分数,query_layer=(1, 12, 3, 64) 
        # key_layer.transpose(-1, -2) 转置了最后两个维度之后由 (1, 12, 3, 64) 变为 (1, 12, 64, 3)
        # 最终计算得到的 attention_scores 的形状为: (1, 12, 3, 64) @ (1, 12, 64, 3) = (1, 12, 3, 3)
        # 这个 (3, 3) 每一行表示某个 token 对其他 token 的注意力分数
        # 3 行表示 3 个 token 分别对其他 token 的注意力分数
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        #############################第4步结束##############################


        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        #############################第5步开始##############################
        # 这里对注意力分数进行了缩放 math.sqrt(self.attention_head_size) = math.sqrt(64) = 8
        # 为什么要这么做?
        # 因为后面要将分数变成概率表示,如果分数之间差值很大,就会导致计算概率时有些值变成了 0,使得 token 无法注意力到该 token
        # 所以这里对分数进行缩放之后,再将其转换为注意力概率表示
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        #############################第5步结束##############################


        #############################第6步开始##############################
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            # 猛一看这里对分数加上了 attention_mask,其实该 attention_mask 已经经过了处理
            # 原来我们输入的 mask 是 [1, 1, 1], 这里 attention_mask 已经变成了 [0, 0, 0], 相当于没什么变化
            # 但是假设我们的 mask 是 [1, 1, 0] 的话,也就是说有一个位置不需要计算,此时 attention_mask 会被处理成 [0, 0, -10000]
            # 此时计算得到的 attention_scores 不需要计算的位置就是很小的负数,计算注意力概率值时,该位置相当于 0,相当于掩码了
            attention_scores = attention_scores + attention_mask
        #############################第6步结束##############################


        # Normalize the attention scores to probabilities.
        #############################第7步开始##############################
        # 经过全面对分数的缩放,以及掩码的操作,这里可以正常计算注意力权重分布了
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        #############################第7步结束##############################


        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.

        #############################第8步开始##############################
        # 这一部分就是随机掩码掉一些 token 的注意力权重,减少一些计算量,可能还能增加模型的性能
        # 我的理解是: 输入 200 个 token,我们可能没必须对所有 token 计算注意力,所以随机丢弃一小部分
        attention_probs = self.dropout(attention_probs)
        #############################第8步结束##############################

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        #############################第9步开始##############################
        # 这一步就比较好理解了,注意力权重 (1, 12, 3, 3)@(1, 12, 3, 64)=(1, 12, 3, 64)
        # 分别得到了每一个头的向量表示
        context_layer = torch.matmul(attention_probs, value_layer)
        # 这一步将形状由 (1, 12, 3, 64) 变为 (1, 3, 12, 64)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # 下面两步再将形状由 (1, 3, 12, 64) 变为 (1, 3, 768), 此时得到了多头自注意力的计算结果向量
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        #############################第9步结束##############################

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

另外,需要注意的是,假设 12 个头,3 个 token 的话,那么:

  1. 12 个 head,表示每个 token 都有 12 组 query、key、value
  2. 第一个 token 的 12 个 head 中的第 1 个 query 会关注自身以及另外 2 个 token 的 key,计算得到注意力分数,并转换为概率表示,然后乘以 3 个 token 的 value 向量,最后将这 3 个向量直接相加,得到第 1 token 第 1 个 head 的向量表示
  3. 将 12 个 head 的注意力向量都计算出来,拼接起来就得到了第 1 个 token 的注意力表征向量

2. BertSelfOutput

class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()

        # 线性变换层
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # 层归一化
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # 随机丢弃层
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # 这里做了一个残差连接,将线性变换之后的结果+上原始输入,再进行层归一化计算
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

关于归一化的操作,可以看下面的链接

未经允许不得转载:一亩三分地 » Bert MHA 源码分析
评论 (0)

8 + 9 =