我正在学习 GMM 进行颜色分割。我在网上找到了一个很好的资源,其中包含以下 GMM 代码:
import matplotlib.pyplot as plt
from matplotlib import style
style.use('fivethirtyeight')
import numpy as np
from scipy.stats import norm
np.random.seed(0)
X = np.linspace(-5, 5, num=20)
X0 = X * np.random.rand(len(X)) + 15 # Create data cluster 1
X1 = X * np.random.rand(len(X)) - 15 # Create data cluster 2
X2 = X * np.random.rand(len(X)) # Create data cluster 3
X_tot = np.stack((X0, X1, X2)).flatten() # Combine the clusters to get the random datapoints from above
class GM1D:
def __init__(self, X, iterations):
self.iterations = iterations
self.X = X
self.mu = None
self.pi = None
self.var = None
def run(self):
self.mu = [-8, 8, 5]
self.pi = [1 / 3, 1 / 3, 1 / 3]
self.var = [5, 3, 1]
for iter in range(self.iterations):
r = np.zeros((len(X_tot), 3))
for c, g, p in zip(range(3), [norm(loc=self.mu[0], scale=self.var[0]),
norm(loc=self.mu[1], scale=self.var[1]),
norm(loc=self.mu[2], scale=self.var[2])], self.pi):
r[:, c] = p * g.pdf(X_tot) # Write the probability that x belongs to gaussian c in column c.
for i in range(len(r)):
r[i] = r[i] / (np.sum(self.pi) * np.sum(r, axis=1)[i])
fig = plt.figure(figsize=(10, 10))
ax0 = fig.add_subplot(111)
for i in range(len(r)):
ax0.scatter(self.X[i], 0, c=np.array([r[i][0], r[i][1], r[i][2]]), s=100)
for g, c in zip([norm(loc=self.mu[0], scale=self.var[0]).pdf(np.linspace(-20, 20, num=60)),
norm(loc=self.mu[1], scale=self.var[1]).pdf(np.linspace(-20, 20, num=60)),
norm(loc=self.mu[2], scale=self.var[2]).pdf(np.linspace(-20, 20, num=60))], ['r', 'g', 'b']):
ax0.plot(np.linspace(-20, 20, num=60), g, c=c)
m_c = []
for c in range(len(r[0])):
m = np.sum(r[:, c])
m_c.append(m) # For each cluster c, calculate the m_c and add it to the list m_c
for k in range(len(m_c)):
self.pi[k] = (m_c[k] / np.sum(m_c)) # For each cluster c, calculate the fraction of points pi_c which belongs to cluster c
self.mu = np.sum(self.X.reshape(len(self.X), 1) * r, axis=0) / m_c
var_c = []
for c in range(len(r[0])):
var_c.append((1 / m_c[c]) * np.dot(((np.array(r[:, c]).reshape(60, 1)) * (self.X.reshape(len(self.X), 1) - self.mu[c])).T, (self.X.reshape(len(self.X), 1) - self.mu[c])))
plt.show()
GM1D = GM1D(X_tot, 10)
GM1D.run()
现在我的理解是,在 EM 的最大化步骤中,我们必须更新高斯参数(协方差矩阵、高斯的均值和大小(pi_c))在上面给出的代码中,我可以看到 pi_c 和均值(mu)的值正在更新,但我认为协方差矩阵的值没有更新。但是,当我运行代码时,它似乎正在工作(?)。有人可以帮我确定代码是否正确。代码来自以下资源