期望最大化算法(Expectation Maximization,EM) 是一种基于不完整、包含隐变量观测数据进行统计模型参数估计的方法。
我们知道,统计模型中的参数都需要根据观测数据集(训练数据)来进行估计。但是,在有些场景下,观测数据集中包含的信息不完整,有缺失,此时就不太容易去估计相应的参数。EM 算法就是针对这种问题的的方法。
1. 问题场景
桌子上放着一个盒子,其中有两种类型的硬币 A 和 B。随机从盒子中抓取硬币是 A 的概率为 π, 硬币 A 和 B 正面朝上的概率分别是 p1 和 p2。接下来,估计 p1、p2、π 的值。
对于这个问题:p1、p2、π 就是参数值。我们看看在完整观测数据下、以及不完整观测数据下的问题如何解决?
假设得到的完整的观测数据如下表,1 表示正面朝上,0 表示反面朝上。
数据 | 硬币 |
1 | A |
0 | A |
1 | B |
1 | A |
0 | B |
估计参数:
- π = 3/5
- p1 = 2/3
- p2 = 1/2
数据 | 硬币 A 概率 | 硬币 B 概率 |
1 | 0.6 | 0.4 |
0 | 0.4 | 0.6 |
1 | 0.6 | 0.4 |
1 | 0.6 | 0.4 |
0 | 0.4 | 0.6 |
估计参数:
- π = (0.6 + 0.4 + 0.6 + 0.6 + 0.4) / 5 = 0.52
- p1 = (0.6 * 1 + 0.4 * 0 + 0.6 * 1 + 0.6 * 1 + 0.4 * 0 ) / (0.6 + 0.4 + 0.6 + 0.6 + 0.4) = 0.6923
- p2 = (0.4 * 1 + 0.6 * 0 + 0.4 * 1 + 0.4 * 1 + 0.6 * 0) / (0.4 + 0.6 + 0.4 + 0.4 + 0.6) = 0.5000
假设:每次不知道抛掷的是哪个硬币、也不知道可能硬币的概率,如下图所示:
数据 | 硬币 | 硬币 A 概率 | 硬币 B 概率 |
1 | ? | ? | ? |
0 | ? | ? | ? |
1 | ? | ? | ? |
1 | ? | ? | ? |
0 | ? | ? | ? |
我们需要在包含缺失信息、隐藏变量的观测数据中去估计 p1、p2、pi 这三个参数。这样的问题有个特点:
- 想估计参数值,必须知道隐藏变量的信息
- 想知道隐藏变量的信息,必须知道参数值
2. 算法思路
对于这样的问题,我们可以假设一组参数,并在此参数的基础上,根据观测数据去估计隐藏变量的分布信息,然后再根据新的隐藏变量信息反过来去更新参数。这个过程可以反复进行,但是什么时候停止这个循环的过程?
我们可以设定一个条件,观测数据的对数似然值最大。即:当 p1、p2、π 是什么值的时候,这组观测数据出现的可性能最大。例如:
数据 | 硬币 A | 硬币 B |
1 | π * p1 | (1 – π ) * p2 |
0 | π * (1 – p1) | (1 – π ) * (1 – p2) |
1 | π * p1 | (1 – π ) * p2 |
1 | π * p1 | (1 – π ) * p2 |
0 | π * (1 – p1) | (1 – π ) * (1 – p2) |
每个样本的对数似然值计算公式:
上面提到的解决思路就是 EM 算法的思想:
- 根据一定的策略对参数进行初始化(当然不准)
- E 步(Expectation Step):基于当前参数估计值,计算每个观测数据属于不同类别的概率;
- M 步(Maximization Step):基于 E 步中的计算结果,更新参数值,最大化对数似然函数。
E 步:
数据 | 硬币 A 概率 | 硬币 B 概率 |
1 | \( \gamma_{0}^{A} = \frac{π * p1}{π * p1 + (1 – π ) * p2} \) | \( \gamma_{0}^{B} = 1- \gamma_{0}^{A}\) |
0 | \( \gamma_{1}^{A} = \frac{π * (1 – p1)}{π * (1 – p1) + (1 – π ) * (1 – p2)} \) | \( \gamma_{1}^{B} = 1- \gamma_{1}^{A}\) |
1 | … | … |
1 | … | … |
0 | … | … |
M 步:
EM 重复执行 E 步和 M 步,直至收敛:
- 似然函数的收敛:EM算法每次迭代会使得似然函数的值逐步增加,一个常见的收敛条件是似然函数的值在两次迭代间的变化小于某个预设的阈值;
- 参数的收敛:参数的变化小于某个阈值也是常用的收敛条件之一,即当两次迭代间模型参数 θ 的变化量小于给定的阈值时,算法认为已经收敛;
- 最大迭代次数:当达到最大迭代次数时即停止。
注意:EM 算法可能收敛到局部最优解,对初始值敏感,且收敛速度较慢。因此,实际应用时,需要进行一些改进来提升其性能。
接下来,我们来观察下上面提到问题的计算过程:
import numpy as np def e_step(coins, pi, p1, p2): resp = [] for coin in coins: a = pi * p1 ** coin * (1 - p1) ** (1 - coin) b = (1 - pi) * p2 ** coin * (1 - p2) ** (1 - coin) ap = a / (a + b) bp = b / (a + b) resp.append((ap, bp)) return np.array(resp) def m_step(coins, resp): pi = sum(resp[:, 0]) / len(resp[:, 0]) p1 = np.dot(coins, resp[:, 0]) / sum(resp[:, 0]) p2 = np.dot(coins, resp[:, 1]) / sum(resp[:, 1]) return pi, p1, p2 def em(coins): # 初始化参数 pi, p1, p2 = 0.5, 0.6, 0.4 # 迭代计算 for _ in range(3): # 打印参数 print('pi: %.5f p1: %.5f p2: %.5f' % (pi, p1, p2)) resp = e_step(coins, pi, p1, p2) # 输出概率 print(resp) print('-' * 30) pi, p1, p2 = m_step(coins, resp) if __name__ == '__main__': coins = [1, 0, 1, 1, 0] em(coins)
我们发现,迭代一次之后参数就不再发生变化,这是因为我们的问题、观测数据都过于简单,在较为复杂的问题中,需要迭代更多的步骤才能收敛。