BiLSTM + CRF 中的 CRF 层重要的是两个函数的实现,一个是损失的计算,一个是维特比解码算法的实现。前者用于模型在训练过程中学习网络参数,后者用于预测最优的解码输出。
1. 矩阵扩展
CRF 层无论是计算损失,还是计算输出标签序列时,都需要用到两个矩阵:由 BiLSTM 输出的发射分数矩阵,以及 CRF 层的转移矩阵,该矩阵也是 CRF 层在训练时要学习的参数。CRF 参数矩阵如下图所示:

\(L_{start}\) 和 \(L_{end}\) 的标签值为 num_label 和 num_label + 1。由于任意标签不可能出现转移到 \(L_{start}\) 的情况,并且 \(L_{end}\) 标签也不可能转移其他的标签,所以把这些位置初始化为 -1000,CRF 参数矩阵中的其他值使用随机初始化即可。
BiLSTM 输出的分数矩阵 (发射分数矩阵) 如下图中左图所示,但是为了后面的计算方便,需要将该矩阵进行扩展,扩展之后的发射矩阵如下图中的右图所示,其实也就是增加 start 和 end 两个位置,并将相关位置初始化为 0 或者一个很小的值,这里很小的值我们用的是 -1000。

2. 计算损失
BiLSTM + CRF 的损失函数的目标就是真实标签序列的总分数占所有可能标签序列的比重越高越好。所以其分子为真实标签序列的分数 (分数包括发射分数、转移分数的总和),而分母则是所有可能的标签序列的分数总和,每个可能的标签序列的分数包括其发射分数和转移分数。

BiLSTM +CRF 的两项计算如下:
- 真实标签序列的分数需要:样本的真实标签序列 + CRF 层的转移参数矩阵 + BiLSTM 输出的发射矩阵。
- 所有可能标签序列分数需要:CRF 层的转移参数矩阵 + BiLSTM 输出的发射矩阵。
3. 维特比解码
维特比解码其实就是根据 BiLSTM 输出的发射矩阵和 CRF 层的转移矩阵来计算分数最高的标签序列。在计算之前,先将发射矩阵按照上面的方法进行维度扩展。假设: 我们有 2 个标签 \(L_1、L_2\)、输入有 3 个字 \(w_0、w_1、w_2\),此时我们的得到的两个矩阵分别如下:

我们可以先计算 \(w_0\) 的各个路径的得分,先定义 pre、obs 变量,如下图所示:

然后令三个矩阵相加,即可得到不同路径的分数,如下图所示:

最右侧的输出结果中,每一个元素都代表了一种可能的路径。每一列分别表示从前一个标签到当前标签所有路径,最右图左上角第一个元素代表:第一个字预测为 \(L_1\) 标签的分数 \(e_{01}\) 加上 \(e_{s1} -> e_{01}\) 的转移分数 \(t_{11}\)。
接下来,计算每一列的最大值,假设每一列的最大值计算之后如下图所示:

max 函数会返回每一列中最大的分数、以及一个索引。例如: 0 索引表示到达 \(e_{s1} – > e_{01}\) 这个路径是 4 个路径中分数最高的,再直接一些的话,就表示从 0 标签到当前标签的分数最高,先记录下来,在回溯时用来获得最优路径。而获得的 4 个最大分数用于继续根据上面的方法向下结算。

当下一轮迭代开始的时候,obs 则选择 \(w_1\) 这一行,pre 则为上一步 max 计算出的 4 个分数值,转移矩阵仍为原来的,不发生变化。假设:我们已经按照上面的方式计算完所有路径的最大分数,就会得到 2 个矩阵,如下图所示:

回溯的时候,根据上图,先找到最优一个的最大值的索引。假设此时的最大值的索引为 2,从绿色的最后一行找到依赖的 2 索引位置的 1,再找倒数绿色第二个的索引为 1 的值 0,再找倒数绿色第三个索引为 0 的值 1,以此类推… 最后的序列为 [2, 1, 0, 1, 1],我们要的是正数的路径,将其翻转 [1, 1, 0, 1, 2],因为我们在计算的时候加了 START 和 END 两个标签,这俩不算,掐头去尾,最后得到的路径为 [1, 0, 1]。
4. 示例代码
程序输出结果:
真实标签:
tensor([1, 1, 0, 2, 0])
预测分数:
tensor([[ 0.1074, 0.5337, -0.7819],
[ 0.8806, 0.5112, -0.3205],
[ 0.5401, -2.2218, -0.8034],
[ 0.6645, 0.6061, -1.4834],
[ 1.2066, 0.1034, 0.3215]])
--------------------------------------------------
loss: tensor(6.0321, grad_fn=<SubBackward0>)
tags: [3, 2, 1, 0, 0]
示例代码:
class CRF(nn.Module):
def __init__(self, label_num):
super(CRF, self).__init__()
# 转移矩阵的标签数量
self.label_num = label_num
# [TAG1, TAG2, TAG3...STAR, END]
params = torch.randn(self.label_num + 2, self.label_num + 2)
self.transition_scores = nn.Parameter(params)
# 开始和结束标签
START_TAG, ENG_TAG = self.label_num, self.label_num + 1
self.transition_scores.data[:, START_TAG] = -1000
self.transition_scores.data[ENG_TAG, :] = -1000
# 定义一个较小值用于扩展发射和转移矩阵时填充
self.fill_value = -1000.0
def _log_sum_exp(self, score):
max_score, _ = torch.max(score, dim=0)
max_score_expand = max_score.expand(score.shape)
return max_score + torch.log(torch.sum(torch.exp(score - max_score_expand), dim=0))
def _get_real_path_score(self, emission_score, sequence_label):
# 计算标签的数量
seq_length = len(sequence_label)
# 计算真实路径发射分数
real_emission_score = torch.sum(emission_score[list(range(seq_length)), sequence_label])
# 在真实标签序列前后增加一个 start 和 end
b_id = torch.tensor([self.label_num], dtype=torch.int32, device=device)
e_id = torch.tensor([self.label_num + 1], dtype=torch.int32, device=device)
sequence_label_expand = torch.cat([b_id, sequence_label, e_id])
# 计算真实路径转移分数
pre_tag = sequence_label_expand[list(range(seq_length + 1))]
now_tag = sequence_label_expand[list(range(1, seq_length + 2))]
real_transition_score = torch.sum(self.transition_scores[pre_tag, now_tag])
# 计算真实路径分数
real_path_score = real_emission_score + real_transition_score
return real_path_score
def _expand_emission_matrix(self, emission_score):
# 计算标签的数量
sequence_length = emission_score.shape[0]
# 扩展时会增加 START 和 END 标签,定义该标签的值
b_s = torch.tensor([[self.fill_value] * self.label_num + [0, self.fill_value]], device=device)
e_s = torch.tensor([[self.fill_value] * self.label_num + [self.fill_value, 0]], device=device)
# 扩展发射矩阵为 (self.label_num + 2, self.label_num + 2)
expand_matrix = self.fill_value * torch.ones([sequence_length, 2], dtype=torch.float32, device=device)
emission_score_expand = torch.cat([emission_score, expand_matrix], dim=1)
emission_score_expand = torch.cat([b_s, emission_score_expand, e_s], dim=0)
return emission_score_expand
def _get_total_path_score(self, emission_score):
# 扩展发射分数矩阵
emission_score_expand = self._expand_emission_matrix(emission_score)
# 计算所有路径分数
pre = emission_score_expand[0]
for obs in emission_score_expand[1:]:
# 扩展 pre 维度
pre_expand = pre.reshape(-1, 1).expand([self.label_num + 2, self.label_num + 2])
# 扩展 obs 维度
obs_expand = obs.expand([self.label_num + 2, self.label_num + 2])
# 扩展之后 obs pre 和 self.transition_scores 维度相同
score = obs_expand + pre_expand + self.transition_scores
# 计算对数分数
pre = self._log_sum_exp(score)
return self._log_sum_exp(pre)
def forward(self, emission_scores, sequence_labels):
total_loss = 0.0
for emission_score, sequence_label in zip(emission_scores, sequence_labels):
# 计算真实路径得分
real_path_score = self._get_real_path_score(emission_score, sequence_label)
# 计算所有路径分数
total_path_score = self._get_total_path_score(emission_score)
# 最终损失
finish_loss = total_path_score - real_path_score
total_loss += finish_loss
return total_loss
def predict(self, emission_score):
"""使用维特比算法,结合发射矩阵+转移矩阵计算最优路径"""
# 扩展发射分数矩阵
emission_score_expand = self._expand_emission_matrix(emission_score)
# 计算分数
ids = torch.zeros(1, self.label_num + 2, dtype=torch.long, device=device)
val = torch.zeros(1, self.label_num + 2, device=device)
pre = emission_score_expand[0]
for obs in emission_score_expand[1:]:
# 扩展 pre 维度
pre_expand = pre.reshape(-1, 1).expand([self.label_num + 2, self.label_num + 2])
# 扩展 obs 维度
obs_expand = obs.expand([self.label_num + 2, self.label_num + 2])
# 扩展之后 obs pre 和 self.transition_scores 维度相同
score = obs_expand + pre_expand + self.transition_scores
# 获得当前多分支中最大值的分支索引
value, index = score.max(dim=0)
ids = torch.cat([ids, index.unsqueeze(0)], dim=0)
val = torch.cat([val, value.unsqueeze(0)], dim=0)
# 计算分数
pre = value
# 先取出最后一个的最大值
index = torch.argmax(val[-1])
best_path = [index]
# 再回溯前一个最大值
# 由于为了方便拼接,我们在第一个位置默认填充了0
for i in reversed(ids[1:]):
# 获得分数最大的索引
# index = torch.argmax(v)
# 获得索引对应的标签ID
index = i[index].item()
best_path.append(index)
best_path = best_path[::-1][1:-1]
return best_path

冀公网安备13050302001966号