Even if the initial number of classes is set to 6, you can see that the model eventually converges to 3 classes.
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from scipy.special import digamma
import matplotlib.cm as cm
plt.style.use("ggplot")
# Dimensionality of each observation.
D = 2
# Total number of observations the three clusters are sized from.
N = 2000

# Ground truth of the synthetic mixture: three clusters, each with its own
# mean, an isotropic covariance (a scalar times the identity) and a sample
# count taken as a fixed fraction of N.
mu1, mu2, mu3 = [0, 1], [-1, -1], [1, -1]
sigma1, sigma2, sigma3 = (scale * np.eye(D) for scale in (0.2, 0.1, 0.1))
N1, N2, N3 = (int(N * share) for share in (0.3, 0.5, 0.2))
# Draw the three Gaussian clusters one after another (same order as their
# definitions) and stack the samples row-wise into a single array.
plt.figure(figsize=(5, 5))
clusters = [
    np.random.multivariate_normal(mean, cov, count)
    for mean, cov, count in (
        (mu1, sigma1, N1),
        (mu2, sigma2, N2),
        (mu3, sigma3, N3),
    )
]
data = np.concatenate(clusters)

# Show the raw, unlabelled points on a fixed square window.
plt.scatter(data[:, 0], data[:, 1], s=10)
plt.xlim(-2.1, 2.1)
plt.ylim(-2.1, 2.1)
plt.show()
# Number of mixture components the fit starts with.
K = 6

# Starting point: one mean per row and one 0.1 * I covariance per component.
mu = np.array([
    [0., -0.5], [0., 0.5], [1., 0.5],
    [-1., -0.5], [-1, -1.5], [1, 1.5],
])
S = np.tile(0.1 * np.eye(2), (K, 1, 1))

# Prior hyper-parameters.  The names follow the usual variational-GMM
# notation (presumably Bishop, PRML section 10.2 -- confirm): alpha_0 for the
# mixing weights, (m_0, beta_0) for the component means, (W_0, nu_0) for the
# component precisions.
# NOTE(review): a Wishart distribution needs nu > D - 1; nu_0 = 1 sits
# exactly on that boundary for D = 2 -- confirm this is intended.
alpha_0, beta_0 = 1e-3, 1e-3
m_0 = np.zeros((K, D))
nu_0 = 1
W_0 = np.eye(D)

# Work buffers that the update loop fills in place.
W_k = np.zeros((K, D, D))
E_mu_lam = np.zeros((N, K))
def multi_gauss(x, y, mu, sigma):
    """Evaluate a bivariate normal density at the point(s) (x, y).

    Args:
        x: First coordinate; a scalar or an array-like of any shape.
        y: Second coordinate; a scalar or an array-like broadcastable
            against ``x``.
        mu: Mean vector of the distribution (length 2).
        sigma: 2x2 covariance matrix of the distribution.

    Returns:
        The density at (x, y): a scalar for scalar inputs, otherwise an
        array with the broadcast shape of ``x`` and ``y``.
    """
    # Pair the coordinates along a trailing axis so scipy treats every
    # (x, y) pair as one 2-D point; this lets callers pass whole grids or
    # data columns instead of wrapping the function in np.vectorize.
    points = np.stack(np.broadcast_arrays(x, y), axis=-1)
    return stats.multivariate_normal(mu, sigma).pdf(points)
# 1: initialise the responsibilities r_nk from the starting mixture.
# Every component starts with the same mixing weight.
pi = np.ones(K) / K

# g[n, k] = pi_k * N(x_n | mu_k, S_k).  The density of each component is
# evaluated on the whole data array in a single call, so only K frozen
# distributions are built instead of one per (point, component) pair.
g = np.zeros((N, K))
for k in range(K):
    g[:, k] = pi[k] * stats.multivariate_normal(mu[k], S[k]).pdf(data)

# Normalise each row so the responsibilities of a point sum to one.
r = g / g.sum(axis=1, keepdims=True)
# Plot the starting state: density contours of every initial component and
# the data coloured by the component with the largest responsibility.
X, Y = np.meshgrid(np.linspace(-2.1, 2.1), np.linspace(-2.1, 2.1))
cmap_colors = [cm.spring, cm.summer, cm.autumn, cm.winter, cm.Reds_r, cm.Dark2]
colors = ["pink", "green", "orange", "blue", "red", "black"]

plt.figure(figsize=(5, 5))
for k in range(K):
    # Evaluate the density on the whole grid at once; np.dstack pairs the
    # coordinates along the last axis as scipy expects.
    Z = stats.multivariate_normal(mu[k], S[k]).pdf(np.dstack((X, Y)))
    plt.contour(X, Y, Z, cmap=cmap_colors[k], alpha=0.5)

# scatter() needs a real sequence of colours: on Python 3 ``map`` returns a
# lazy iterator that matplotlib rejects, so build a list instead.
point_colors = [colors[label] for label in r.argmax(1)]
plt.scatter(data[:, 0], data[:, 1], c=point_colors, alpha=0.3, s=10)
plt.xlim(-2.1, 2.1)
plt.ylim(-2.1, 2.1)
init_title = "iter: 0"
plt.title(init_title)
#plt.savefig("data/" + init_title + ".png ")
plt.show()
# Variational Bayes updates for the Gaussian mixture; one figure is drawn
# after every sweep so the sequence can be assembled into an animation.

# Loop invariants: the inverse prior scale matrix and the plotting grid.
W_0_inv = np.linalg.inv(W_0)
X, Y = np.meshgrid(np.linspace(-2.1, 2.1), np.linspace(-2.1, 2.1))

for i in range(20):
    # 2: responsibility-weighted statistics of every component
    #    (effective count, mean and covariance).
    N_k = r.sum(0)
    mu = r.T.dot(data) / np.c_[N_k]
    for k in range(K):
        centered = data - mu[k]
        S[k] = (np.c_[r[:, k]] * centered).T.dot(centered) / N_k[k]

    # 3: M step -- update the posterior hyper-parameters.
    alpha = alpha_0 + N_k
    beta = beta_0 + N_k
    m_k = (beta_0 * m_0 + np.c_[N_k] * mu) / np.c_[beta]
    for k in range(K):
        shift = mu[k] - m_0[k]
        tmp1 = beta_0 * N_k[k] * np.outer(shift, shift) / (beta_0 + N_k[k])
        # np.linalg is used directly: the file never imports an ``LA`` alias.
        W_k[k] = np.linalg.inv(W_0_inv + N_k[k] * S[k] + tmp1)
    nu_k = nu_0 + N_k

    # 4: E step -- expectations under the current posterior.
    # E[ln|Lambda_k|] = sum_{d=1..D} digamma((nu_k + 1 - d) / 2)
    #                   + D ln 2 + ln|W_k|
    # where |W_k| is the determinant of W_k (not a matrix norm); slogdet
    # returns its logarithm for the whole stack without under/overflow.
    dims = np.arange(1, D + 1)
    E_ln_lam = (
        digamma((nu_k[:, np.newaxis] + 1 - dims) / 2.0).sum(axis=1)
        + D * np.log(2)
        + np.linalg.slogdet(W_k)[1]
    )
    E_ln_pi = digamma(alpha) - digamma(alpha.sum())
    for k in range(K):
        centered = data - m_k[k]
        # Row-wise quadratic form x^T W x; avoids building the N x N matrix
        # whose diagonal the same values would otherwise be read from.
        quad = (centered.dot(W_k[k]) * centered).sum(axis=1)
        E_mu_lam[:, k] = D / beta[k] + nu_k[k] * quad

    ln_ro = E_ln_pi + E_ln_lam / 2. - D * np.log(2 * np.pi) / 2. - E_mu_lam / 2.
    # Subtracting the row maximum leaves the normalised responsibilities
    # unchanged but keeps exp() from underflowing a whole row to zero.
    ln_ro -= ln_ro.max(axis=1, keepdims=True)
    ro = np.exp(ln_ro)
    r = ro / np.c_[ro.sum(1)]
    # Floor the responsibilities so no component count reaches exactly zero.
    r[r < 1e-10] = 1e-10

    # Draw the frame for this sweep.
    plt.figure(figsize=(5, 5))
    pi = np.exp(E_ln_pi)
    for k in range(K):
        # Only components that still carry noticeable weight get contours.
        if pi[k] > 0.01:
            Z = stats.multivariate_normal(mu[k], S[k]).pdf(np.dstack((X, Y)))
            plt.contour(X, Y, Z, cmap=cmap_colors[k], alpha=0.5)
    # A list is required here: ``map`` is a lazy iterator on Python 3 and
    # matplotlib cannot interpret it as a colour sequence.
    point_colors = [colors[label] for label in r.argmax(1)]
    plt.scatter(data[:, 0], data[:, 1], c=point_colors, s=10, alpha=0.3)
    plt.xlim(-2.1, 2.1)
    plt.ylim(-2.1, 2.1)
    title = "iter: {}".format(i + 1)
    plt.title(title)
    #plt.savefig("data/" + title + ".png ")
    plt.show()
Recommended Posts