高斯混合模型(GMM)

高斯混合模型(Gaussian Mixture Model, GMM)通过多个高斯分布的加权来描述一个随机变量的概率分布,它的公式表示如下:

  1. K 表示高斯分布的数量
  2. wi 表示将多个高斯分布混合在一起的混合系数,其和为1
  3. μi 表示每个高斯分布的均值
  4. i 表示每个高斯分布的协方差矩阵

上述公式中,每个高斯分布也叫做一个成分(component), 高斯混合模型的概率分布如下图所示:

图像生成的代码为:

import numpy as np
import matplotlib.pyplot as plt
import math


def gaussian_distribution(x, mean=0, std=1):
    return np.exp(-(x - mean) ** 2 / (2 * std ** 2)) / (math.sqrt(2 * math.pi) * std)


def test():

    x = np.linspace(-15, 15, 1000)
    y1 = gaussian_distribution(x, mean=-3.0, std=1.5)
    y2 = gaussian_distribution(x, mean=+3.0, std=2.0)
    y3 = gaussian_distribution(x, mean=+6.0, std=1.5)
    y4 = 0.3*y1 + 0.2*y2 + 0.5*y3

    plt.figure(figsize=(12, 8), dpi=80)
    plt.plot(x, y1, color='r', label='均值=-3.0, 标准差=1.5', alpha=0.3)
    plt.plot(x, y2, color='g', label='均值=+3.0, 标准差=2.0', alpha=0.3)
    plt.plot(x, y3, color='c', label='均值=+6.0, 标准差=1.5', alpha=0.3)
    plt.plot(x, y4, color='k', label='高斯分布加权线性组合')

    plt.grid(True)
    plt.legend(loc='best')
    plt.show()


if __name__ == '__main__':
    test()

Kmeans 算法是基于质心的聚类算法,它将样本归属到欧式距离最近的簇。高斯混合模型是将每个样本归属到概率最高的高斯分布,每个高斯分布可以理解为一个簇。

假设我们有一个样本,并且有 3 个高斯分布,要判断其归属为那个高斯分布,当我们知道每一个高斯分布的参数 μ 和 ∑ 时,就可以判断出该样本归属为每个高斯分布的 w,我们将样本归属到 w 最大对应的高斯分布就得到了样本的所属的簇。

高斯混合模型未知的参数有: 每个高斯分布的权重 w、期望 μ、协方差 ∑。如何估计这些参数就是高斯混合模型的学习过程。

上面的给出的高斯混合模型的概率分布公式可以计算出某个样本的概率,假设我们共有 N 个样本,这些样本的联合概率分布表示为下面的公式:

上面的公司就是采用极大似然估计法得到的公式,我们一般将其转换为对数函数,如下所示:

如果上面的公式只有一个高斯分布模型的话,使用极大似然估计估计模型参数 μ、∑。但是上面的是由多个高斯模型组合在一起的,如果要估计参数的话,我们得先知道训练样本中哪些样本属于第一个高斯分布、哪些样本属于第二个高斯分布…,比如:知道了前100个样本属于第一个分布、接下来 100 个样本属于第二个分布 … 以此类推,那我们对每个高斯分布使用极大似然估计就可以得到模型参数了。

也就是说,我们的求解过程中是包含隐变量的,而这个隐变量就是训练样本所属的类别。这时,我们马上就想到了对于包含隐变量的参数估计可以使用 EM 算法。

公式 EM 推导部分暂时省略,后面有时间再补充,此时得到高斯混合模型的算法流程如下:

初始化高斯混合模型中每个高斯分布模型的参数:权重 w、期望 μ、协方差 ∑。

使用下列公式计算第i个样本来自第j个高斯分布的概率

更新每个高斯分布权重的公式是:

更新每个高斯分布均值的公式是:

更新每个高斯分布协方差矩阵的公式是:

接下来,使用 scikit-learn 中的混合高斯模型 API 进行聚类:

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
import numpy as np


def test():

    # 固定随机数种子
    np.random.seed(0)

    x, y = make_blobs(n_samples=500,
                      n_features=2,
                      centers=[(0, 1.5), [1, 0.5], [2, 2]],
                      cluster_std=[0.4, 0.5, 0.5],
                      random_state=0)


    # 实例化高斯混合模型, n_components 表示簇的个数
    estimator = GaussianMixture(n_components=4)
    estimator.fit(x)
    y_pred = estimator.predict(x)
    plt.scatter(x[:, 0], x[:, 1], c=y_pred)
    plt.show()

    # 高斯混合模型预测概率值
    y_prob = estimator.predict_proba(x)
    # print(y_prob)

    # 打印中心位置
    print('权重:\n', estimator.weights_)
    print('均值:\n', estimator.means_)
    print('协方差:\n', estimator.covariances_)


if __name__ == '__main__':
    test()

程序输出结果:

权重:
 [0.31630872 0.13415278 0.33592098 0.21361752]
均值:
 [[ 0.9839625   0.41397728]
 [ 1.95119411  2.38778786]
 [-0.05309072  1.47808799]
 [ 1.92833325  1.72952516]]
协方差:
 [[[ 2.08624956e-01  3.39826956e-02]
  [ 3.39826956e-02  2.46894683e-01]]

 [[ 1.24763427e-01  4.58519820e-04]
  [ 4.58519820e-04  1.77589744e-01]]

 [[ 1.28429081e-01 -1.73218188e-02]
  [-1.73218188e-02  1.81657443e-01]]

 [[ 3.75371903e-01 -3.11420985e-04]
  [-3.11420985e-04  1.37567448e-01]]]

下面绘制等高线图:

import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
import numpy as np


def test():

    # 固定随机数种子
    np.random.seed(0)

    x, y = make_blobs(n_samples=500,
                      n_features=2,
                      centers=[(0, 1.5), [1, 0.5], [2, 2]],
                      cluster_std=[0.2, 0.2, 0.4])


    # 实例化高斯混合模型, n_components 表示簇的个数
    estimator = GaussianMixture(n_components=3, covariance_type='full')
    estimator.fit(x)

    X, Y = np.meshgrid(np.linspace(x[:, 0].min(), x[:, 0].max(), 1000),
                       np.linspace(x[:, 1].min(), x[:, 1].max(), 1000))
    data = np.c_[X.ravel(), Y.ravel()]
    y_pred = estimator.score_samples(data).reshape(X.shape)
    plt.contour(X, Y, y_pred)
    plt.scatter(x[:, 0], x[:, 1], c=y)
    plt.show()


if __name__ == '__main__':
    test()

未经允许不得转载:一亩三分地 » 高斯混合模型(GMM)