跳转至

基于 EM 算法的多元高斯混合模型聚类及其 Python 实现

基于 EM 算法,推导多元高斯混合模型聚类的参数迭代公式,并使用 Python 对数据集进行聚类和各类别的参数求解。

在编写代码的过程中,遇到了一个非常简单但一直没发现的 Bug。

定义数组用all_density = np.array([0]*K),再用all_density[k] = k_density并不会让all_density的第k个元素改变。这是因为all_density是介于 0 到 1 之间的,而在定义all_density的时候没有指定数组内部的数据类型,默认是不支持小数的,因此赋值之后all_density的第k个元素仍然是 0。

解决方法:定义数组的时候一定要指定元素的数据类型,指定为dtype=flout64就可以存储高精度的浮点数。

result

问题描述

problem-statement

EM 算法的基本思想和具体步骤

基本思想

idea

具体步骤

steps

关于初始值和终止条件的几点说明

comments

推导多元高斯混合模型聚类的参数迭代公式

math-1

math-2

math-3

math-4

math-5

math-6

使用 Python 对数据集进行聚类和各类别的参数求解

数据集描述

programming-problem

dataset

导入包

Python
# 导入包
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
# 显示中文
plt.rcParams['font.sans-serif'] = ['SimSun']
plt.rcParams['axes.unicode_minus'] = False
# 渲染公式
plt.rcParams['text.usetex'] = True

读取数据

Python
# 读取数据,跳过第一列
data = pd.read_csv('EM.csv', index_col=0).values
# 识别数据维度
n, p = data.shape
# 可能来自的分布数量
K = 2

给出所有参数的初始值

Python
# 初始化每个数据点来自哪个分布的概率
pai = np.array([0.5]*K, dtype=np.float64)
# 初始化均值矩阵(p*K)。一个分布的均值都是 1,另一个分布的均值都是 0
mu = np.array([[1]*p, [0]*p ], dtype=np.float64).T
# 初始化协方差矩阵(p*p*K)。每个分布的协方差都是单位矩阵
sigma = np.array([np.eye(p)]*K, dtype=np.float64)
# 打印初始值
print('pai:\n{}\n\nmu:\n{}\n\nsigma:\n{}'.format(pai, mu, sigma))
Text Only
## pai:
## [0.5 0.5]
##
## mu:
## [[1. 0.]
##  [1. 0.]
##  [1. 0.]
##  [1. 0.]
##  [1. 0.]
##  [1. 0.]
##  [1. 0.]
##  [1. 0.]
##  [1. 0.]
##  [1. 0.]]
##
## sigma:
## [[[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
##   [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
##   [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
##   [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
##   [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
##   [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
##   [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
##   [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
##   [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
##   [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
##
##  [[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
##   [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
##   [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
##   [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
##   [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
##   [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
##   [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
##   [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
##   [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
##   [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]]

迭代更新参数

Python
# 求高斯分布的概率密度函数
def gaussian(x, mu, sigma):
    return 1/np.sqrt((2*np.pi)**p*np.linalg.det(sigma))*np.exp(-0.5*(x-mu).T.dot(np.linalg.inv(sigma)).dot(x-mu))
# 求给定参数下,数据点属于各分布的概率
def prob(x, mu, sigma, pai, K):
    all_density = np.array([0]*K, dtype=np.float64)
    for k in range(K):
        k_density = pai[k]*gaussian(x, mu[:,k], sigma[k])
        all_density[k] = k_density
    return all_density/all_density.sum()
Python
all_iteration_pai = [pai]
all_iteration_mu = [mu]
all_iteration_sigma = [sigma]
iteration_count = 0
while True:
    iteration_count += 1
    # 将所有样本的属于各分布的概率初始化为 0
    n_prob = np.array([[0]*K]*n, dtype=np.float64)
    # 求所有样本的属于各分布的概率
    for i in range(n):
        # 求第 i 个样本属于各分布的概率
        x = data[i,:].T
        prob_i = prob(x, mu, sigma, pai, K)
        n_prob[i] = prob_i
    # 更新 pai_k
    pai = n_prob.sum(axis=0)/n
    # 将本轮迭代的 pai 保存
    all_iteration_pai.append(pai)
    # 更新 mu
    mu = data.T.dot(n_prob)/n_prob.sum(axis=0)
    # 将本轮迭代的 mu 保存
    all_iteration_mu.append(mu)
    # 更新 sigma
    for k in range(K):
        sigma_k = np.zeros((p,p), dtype=np.float64)
        for i in range(n):
            x = data[i,:].T # x 是列向量
            sigma_k += n_prob[i,k]*np.outer(x-mu[:, k], x-mu[:, k])
        sigma_k /= n_prob.sum(axis=0)[k]
        # sigma_k 是 10*10 的矩阵
        sigma[k] = sigma_k
    # 将本轮迭代的 sigma 保存
    all_iteration_sigma.append(sigma)
    # 判断是否达到终止 1
    new_mu = all_iteration_mu[-1]
    old_mu = all_iteration_mu[-2]
    # 当本次迭代得到的 mu 和上次迭代得到的 mu 几乎没有差别时,可以终止迭代
    if (np.subtract(new_mu, old_mu)**2).sum()<1e-8:
        break

打印出迭代的次数和参数的值

Python
print('迭代次数:{}'.format(iteration_count))
Text Only
## 迭代次数:17
Python
print('pai:\n{}\n'.format(pai))
Text Only
## pai:
## [0.58870318 0.41129682]
Python
print('mu:\n{}\n'.format(mu))
Text Only
## mu:
## [[ 1.01500407  0.18662592]
##  [ 0.97982009  0.01332565]
##  [ 1.00858025 -0.12719368]
##  [ 1.0234199  -0.02555446]
##  [ 1.04411481  0.00185229]
##  [ 1.02298727  0.00717931]
##  [ 1.01387045 -0.05989538]
##  [ 0.97264306 -0.08819002]
##  [ 0.99824314  0.22147017]
##  [ 0.92537071 -0.00442322]]
Python
print('sigma:\n{}'.format(sigma))
Text Only
## sigma:
## [[[ 4.77850968e-01  2.03033657e-02  3.78546195e-02 -1.01181913e-03
##    -2.96115297e-02 -5.54141629e-03 -1.47452695e-02  1.37669547e-02
##    -6.28394728e-02 -1.21109511e-02]
##   [ 2.03033657e-02  4.77796696e-01 -6.62293811e-03  3.75985666e-02
##    -2.04491793e-02 -2.98300772e-03 -8.21489696e-03 -3.22317835e-02
##    -5.36574177e-03  4.54354716e-02]
##   [ 3.78546195e-02 -6.62293811e-03  4.54636109e-01 -3.51923568e-02
##     6.34667740e-03 -3.56763317e-02 -1.53553017e-03 -1.43510168e-02
##    -2.67175562e-02  2.75253474e-02]
##   [-1.01181913e-03  3.75985666e-02 -3.51923568e-02  4.54814674e-01
##    -1.09588387e-02 -8.04964137e-03 -6.79355591e-02 -1.58610602e-02
##     1.38411070e-03  2.97408576e-02]
##   [-2.96115297e-02 -2.04491793e-02  6.34667740e-03 -1.09588387e-02
##     4.18149094e-01  5.14822421e-02 -2.70075163e-02 -4.85719583e-02
##    -3.90824938e-02 -9.80214724e-03]
##   [-5.54141629e-03 -2.98300772e-03 -3.56763317e-02 -8.04964137e-03
##     5.14822421e-02  5.57272122e-01 -3.35251541e-03  1.08618076e-02
##    -1.70210877e-02 -3.56956117e-02]
##   [-1.47452695e-02 -8.21489696e-03 -1.53553017e-03 -6.79355591e-02
##    -2.70075163e-02 -3.35251541e-03  4.87019479e-01 -3.51035789e-02
##    -1.15753668e-02 -1.56505402e-02]
##   [ 1.37669547e-02 -3.22317835e-02 -1.43510168e-02 -1.58610602e-02
##    -4.85719583e-02  1.08618076e-02 -3.51035789e-02  4.23483429e-01
##     4.98114718e-02  2.34207606e-03]
##   [-6.28394728e-02 -5.36574177e-03 -2.67175562e-02  1.38411070e-03
##    -3.90824938e-02 -1.70210877e-02 -1.15753668e-02  4.98114718e-02
##     4.82600350e-01 -1.92214963e-02]
##   [-1.21109511e-02  4.54354716e-02  2.75253474e-02  2.97408576e-02
##    -9.80214724e-03 -3.56956117e-02 -1.56505402e-02  2.34207606e-03
##    -1.92214963e-02  5.65086238e-01]]
##
##  [[ 9.29295093e-01  4.08010510e-02 -8.23808935e-04  8.44863621e-03
##     1.20218799e-01 -1.11470574e-01  1.53808551e-02  7.88913749e-02
##    -2.78554252e-02  3.22006682e-02]
##   [ 4.08010510e-02  9.83440827e-01 -5.49377007e-02 -2.71323305e-02
##    -5.45668743e-02  1.21946570e-01  2.00130257e-02  9.74783909e-02
##    -1.53586564e-01  1.30733179e-01]
##   [-8.23808935e-04 -5.49377007e-02  1.07523742e+00  7.43292761e-02
##     3.62964590e-02 -2.81937066e-02  1.96732798e-02 -1.07758450e-01
##     8.91156366e-02 -1.92974649e-02]
##   [ 8.44863621e-03 -2.71323305e-02  7.43292761e-02  9.11091322e-01
##     1.89237737e-02 -7.29374832e-02 -1.32674621e-01 -7.57571711e-02
##     5.28107636e-02 -1.12753582e-01]
##   [ 1.20218799e-01 -5.45668743e-02  3.62964590e-02  1.89237737e-02
##     8.88748718e-01 -9.57949898e-02 -5.23966099e-02 -9.07784608e-02
##     3.94038498e-02 -7.42123835e-02]
##   [-1.11470574e-01  1.21946570e-01 -2.81937066e-02 -7.29374832e-02
##    -9.57949898e-02  1.00163225e+00  2.84507420e-02  1.29356885e-01
##     8.54158988e-02  6.04867236e-02]
##   [ 1.53808551e-02  2.00130257e-02  1.96732798e-02 -1.32674621e-01
##    -5.23966099e-02  2.84507420e-02  9.74886562e-01 -1.91740598e-02
##    -1.45956772e-01  4.74506001e-02]
##   [ 7.88913749e-02  9.74783909e-02 -1.07758450e-01 -7.57571711e-02
##    -9.07784608e-02  1.29356885e-01 -1.91740598e-02  1.07694992e+00
##    -1.53518203e-01  1.93095744e-01]
##   [-2.78554252e-02 -1.53586564e-01  8.91156366e-02  5.28107636e-02
##     3.94038498e-02  8.54158988e-02 -1.45956772e-01 -1.53518203e-01
##     1.14354842e+00 -4.74273894e-02]
##   [ 3.22006682e-02  1.30733179e-01 -1.92974649e-02 -1.12753582e-01
##    -7.42123835e-02  6.04867236e-02  4.74506001e-02  1.93095744e-01
##    -4.74273894e-02  1.00534259e+00]]]

绘制各参数的迭代变化图

Python
fig = plt.figure(figsize=(9, 16))
# 画 pai 的迭代过程
ax1 = fig.add_subplot(311)
# 将横坐标刻度设置为整数,字号设置为 18
ax1.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax1.tick_params(labelsize=18)
ax1.set_ylabel(r'$\pi$', size=18)
ax1.plot(range(iteration_count+1), [all_iteration_pai[i][0] for i in range(iteration_count+1)], label=r'$\pi_1$')
ax1.plot(range(iteration_count+1), [all_iteration_pai[i][1] for i in range(iteration_count+1)], label=r'$\pi_2$')
# 显示图例,字号为 18,且图例显示在图外
ax1.legend(fontsize=18, bbox_to_anchor=(1, 1))
# 画第一个分布的 mu 的迭代过程
ax2 = fig.add_subplot(312)
# 将横坐标刻度设置为整数,字号设置为 18
ax2.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax2.tick_params(labelsize=18)
ax2.set_ylabel(r'$\mu_1$', size=18)
for param in range(K*p):
    # 只画第一个分布的 mu
    if param%2==0:
        ax2.plot(range(iteration_count+1), [all_iteration_mu[i][param//2][0] for i in range(iteration_count+1)], label=r'$\mu_{1{%s}}$' % (param//2+1))
# 显示图例,字号为 18
ax2.legend(fontsize=18, bbox_to_anchor=(1, 1))
# 画第二个分布的 mu 的迭代过程
ax3 = fig.add_subplot(313)
# 将横坐标刻度设置为整数,字号设置为 18
ax3.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax3.tick_params(labelsize=18)
ax3.set_xlabel('迭代次数', size=18, usetex=False) # 中文不用 tex 渲染,否则会报错
ax3.set_ylabel(r'$\mu_2$', size=18)
for param in range(K*p):
    # 只画第二个分布的 mu
    if param%2==1:
        ax3.plot(range(iteration_count+1), [all_iteration_mu[i][param//2][1] for i in range(iteration_count+1)], label=r'$\mu_{2{%s}}$' % (param//2+1))
# 显示图例,字号为 18
ax3.legend(fontsize=18, bbox_to_anchor=(1, 1))
# 标题
fig.suptitle('EM 算法迭代过程', size=20, usetex=False) # 中文不用 tex 渲染,否则会报错
# 标题离图像的距离
fig.subplots_adjust(top=0.95)

result

评论