我们在使用 Bert 模型时,对每一个 token 的表征计算都是通过其内部的自注意力机制来完成的,具体就是由 Bert 模型的 BertAttention 来负责自注意力计算,分析的实现代码是 transformers==4.22.2,下面是该类的实现代码:
class BertAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() # 多头自注意力计算层 self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) # 计算最终输出 self.output = BertSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) self.self.key = prune_linear_layer(self.self.key, index) self.self.value = prune_linear_layer(self.self.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs
上面源码的 forward 函数,可以看到计算就分为两步,分别是多头自注意力计算,输出计算。由 init 函数中可以看到,该组件是由 BertSelfAttention 和 BertSelfOutput 两部分组成,我们接下来分析这两部分对输入究竟做了哪些计算过程。
1. BertSelfAttention
BertSelfAttention 是 Bert 多头自注意力计算层,该类主要包含三个函数:
- init 函数,初始化各种需要的组件
- transpose_for_scores 函数,用于在自注意力计算过程中的转置操作
- forward 函数,就是自注意力计算的过程
我们先看 init 函数,看看该类主要初始化了哪些组件,源代码以及注释如下:
def __init__(self, config, position_embedding_type=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " f"heads ({config.num_attention_heads})" ) # 从 BertConfig 对象读取配置的注意力头数量, 这里默认头数是: 12 self.num_attention_heads = config.num_attention_heads # 计算每一个注意力头输出的向量维度 # 我们输入的维度是 config.hidden_size=768, 除以 config.num_attention_heads=12 之后, 得到 64 # 即: 每一个头输出的维度是 64 维,将 12 个头的维度拼接起来,会得到 768 的向量 self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 由于我们这里使用 batch 的概念来批量计算 12 个注意力头,这里就计算下 12 个头输出的向量维度是多少 # 经过计算,self.all_head_size=768 self.all_head_size = self.num_attention_heads * self.attention_head_size # 自注意力计算过程中,QKV 都是通过线性变换得到,这里初始化一个 (768, 12, 64) 的线性层 # 由于线性层只有输入维度和输出维度,所以第二维度 self.all_head_size=12*64=768 # 我们可以理解这个 768 行 768 列的矩阵中,768列实际上分成了 12 个区域,每个区域代表一个头的所有参数 # 其他的 key 和 value 层与 query 同理,所以他们的形状都是一样的 self.query = nn.Linear(config.hidden_size, self.all_head_size) self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) # 下面部分可以先不用关注 self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.position_embedding_type = position_embedding_type or getattr( config, "position_embedding_type", "absolute" ) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder
接下来,我们看下 transpose_for_scores 函数,该函数主要是对接下来进行的多头自注意力计算进行转置操作,便于实现矩阵运算,下面是其源代码实现以及注释:
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: # 从 x.size()[:-1] 可以看到,我们将输入数据的最后一维丢掉,并增加了两个新的维度 # 例如: 我们输入的数据维度是 (1, 3, 768),经过此函数的转置操作之后会变成 (1, 3, 12, 64) # 上面数据形状中的 1 表示 batch size 大小 new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) # 这里再将第二维、第三维交换位置,由 (1, 3, 12, 64) -> (1, 12, 3, 64) # 把头这个维度提前,是为了方便后续分别计算每一个头的向量表示 return x.permute(0, 2, 1, 3)
接下来,我们再分析下最重要的 forward 计算过程,由于源代码内容较多,我们只考虑当输入一个序列数据时,forward 经过了下面的 9 步骤计算。
我们假设输入的数据为:
from transformers import BertModel from transformers import BertConfig def test(): BertModel(config=BertConfig())(input_ids=torch.tensor([[1, 2, 3]])) if __name__ == '__main__': test()
具体的计算步骤下面已经标记出来:
def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: #############################第1步开始############################## # 这一步对输入的隐藏状态向量,按照我们前面假设输入的数据,这里 hidden_states 的形状为 (1, 3, 768) # 我们前面提过 self.query 层包含了 12 个头的参数,所以经过变换之后,就得到输入 hidden_states 的 12 个 query 向量 # self.query 输入和输出维度都是 768, 所以经过变换后依然输出的维度是 mixed_query_layer=(1, 3, 768) # 虽然形状和 hidden_states 一样,但是要知道数据表示的含义可就不同了,输出数据表示 12 个头的 query 向量表示 mixed_query_layer = self.query(hidden_states) #############################第1步结束############################## # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None if is_cross_attention and past_key_value is not None: # reuse k,v, cross_attentions key_layer = past_key_value[0] value_layer = past_key_value[1] attention_mask = encoder_attention_mask elif is_cross_attention: key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = torch.cat([past_key_value[0], key_layer], dim=2) value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: #############################第2步开始############################## # 前面计算过了 query 向量,这里计算的是 12 个头的 key 和 value 向量表示 # 但是这里额外做了转置操作,目的是为了能够 query 和 key 进行矩阵运算 # key、value 经过转置之后,数据由 (1, 3, 768) 变成了 (1, 12, 3, 64) key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) #############################第2步结束############################## #############################第3步开始############################## # 这个将 query 向量也进行转置,从 (1, 3, 768) 变成 (1, 12, 3, 64) # 此时, query、key、value 都变成了 (1, 12, 3, 64) # 这个数据可以理解为: 每个头的 qkv 都是一个 (3, 64) 的向量,3 表示 token 的数量 query_layer = self.transpose_for_scores(mixed_query_layer) #############################第3步结束############################## if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention # key/value_states (first "if" case) # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. #############################第4步开始############################## # 这里开始计算注意力分数,query_layer=(1, 12, 3, 64) # key_layer.transpose(-1, -2) 转置了最后两个维度之后由 (1, 12, 3, 64) 变为 (1, 12, 64, 3) # 最终计算得到的 attention_scores 的形状为: (1, 12, 3, 64) @ (1, 12, 64, 3) = (1, 12, 3, 3) # 这个 (3, 3) 每一行表示某个 token 对其他 token 的注意力分数 # 3 行表示 3 个 token 分别对其他 token 的注意力分数 attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) #############################第4步结束############################## if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": seq_length = hidden_states.size()[1] position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) attention_scores = attention_scores + relative_position_scores elif self.position_embedding_type == "relative_key_query": relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key #############################第5步开始############################## # 这里对注意力分数进行了缩放 math.sqrt(self.attention_head_size) = math.sqrt(64) = 8 # 为什么要这么做? # 因为后面要将分数变成概率表示,如果分数之间差值很大,就会导致计算概率时有些值变成了 0,使得 token 无法注意力到该 token # 所以这里对分数进行缩放之后,再将其转换为注意力概率表示 attention_scores = attention_scores / math.sqrt(self.attention_head_size) #############################第5步结束############################## #############################第6步开始############################## if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in BertModel forward() function) # 猛一看这里对分数加上了 attention_mask,其实该 attention_mask 已经经过了处理 # 原来我们输入的 mask 是 [1, 1, 1], 这里 attention_mask 已经变成了 [0, 0, 0], 相当于没什么变化 # 但是假设我们的 mask 是 [1, 1, 0] 的话,也就是说有一个位置不需要计算,此时 attention_mask 会被处理成 [0, 0, -10000] # 此时计算得到的 attention_scores 不需要计算的位置就是很小的负数,计算注意力概率值时,该位置相当于 0,相当于掩码了 attention_scores = attention_scores + attention_mask #############################第6步结束############################## # Normalize the attention scores to probabilities. #############################第7步开始############################## # 经过全面对分数的缩放,以及掩码的操作,这里可以正常计算注意力权重分布了 attention_probs = nn.functional.softmax(attention_scores, dim=-1) #############################第7步结束############################## # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. #############################第8步开始############################## # 这一部分就是随机掩码掉一些 token 的注意力权重,减少一些计算量,可能还能增加模型的性能 # 我的理解是: 输入 200 个 token,我们可能没必须对所有 token 计算注意力,所以随机丢弃一小部分 attention_probs = self.dropout(attention_probs) #############################第8步结束############################## # Mask heads if we want to if head_mask is not None: attention_probs = attention_probs * head_mask #############################第9步开始############################## # 这一步就比较好理解了,注意力权重 (1, 12, 3, 3)@(1, 12, 3, 64)=(1, 12, 3, 64) # 分别得到了每一个头的向量表示 context_layer = torch.matmul(attention_probs, value_layer) # 这一步将形状由 (1, 12, 3, 64) 变为 (1, 3, 12, 64) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # 下面两步再将形状由 (1, 3, 12, 64) 变为 (1, 3, 768), 此时得到了多头自注意力的计算结果向量 new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) #############################第9步结束############################## outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) if self.is_decoder: outputs = outputs + (past_key_value,) return outputs
另外,需要注意的是,假设 12 个头,3 个 token 的话,那么:
- 12 个 head,表示每个 token 都有 12 组 query、key、value
- 第一个 token 的 12 个 head 中的第 1 个 query 会关注自身以及另外 2 个 token 的 key,计算得到注意力分数,并转换为概率表示,然后乘以 3 个 token 的 value 向量,最后将这 3 个向量直接相加,得到第 1 token 第 1 个 head 的向量表示
- 将 12 个 head 的注意力向量都计算出来,拼接起来就得到了第 1 个 token 的注意力表征向量
2. BertSelfOutput
class BertSelfOutput(nn.Module): def __init__(self, config): super().__init__() # 线性变换层 self.dense = nn.Linear(config.hidden_size, config.hidden_size) # 层归一化 self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # 随机丢弃层 self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) # 这里做了一个残差连接,将线性变换之后的结果+上原始输入,再进行层归一化计算 hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states
关于归一化的操作,可以看下面的链接