Variational Bayesian estimation of mixed Gaussian distribution

Illustrated

変分混合ガウス分布.gif

Even when the initial number of classes is set to 6, you can see that the estimate eventually converges to 3 classes.

algorithm

  1. Initialize r_{nk}
  2. Calculate three statistics
    • N_k = \sum_{n=1}^N r_{nk}
    • {\bar x_k} = \frac{1}{N_k} \sum_{n=1}^N r_{nk}x_n
    • S_k = \frac{1}{N_k} \sum_{n=1}^N r_{nk}(x_n - {\bar x_k})(x_n - {\bar x_k})^T
  3. Mstep: compute q(\pi) = Dir(\pi|\alpha) and q(\mu_k, \Lambda_k) = N(\mu_k|m_k, (\beta_k \Lambda_k)^{-1}) W(\Lambda_k|W_k, \nu_k).
    • \alpha_k = \alpha_0 + N_k
    • \beta_k = \beta_0 + N_k
    • m_k = \frac{1}{\beta_k}(\beta_0 m_0 + N_k {\bar x_k})
    • W_k^{-1} = W_0^{-1} + N_k S_k + \frac{\beta_0 N_k}{\beta_0 + N_k}({\bar x_k} - m_0)({\bar x_k} - m_0)^T
    • \nu_k = \nu_0 + N_k
  4. Estep: q(Z) = \prod_{n=1}^N \prod_{k=1}^K r_{nk}^{z_{nk}}
    • r_{nk} = \frac{\rho_{nk}}{\sum_{j=1}^K \rho_{nj}}
    • \ln\rho_{nk} = E[\ln\pi_k] + \frac{1}{2} E[\ln |\Lambda_k|] - \frac{D}{2}\ln(2\pi) - \frac{1}{2}E_{\mu_k, \Lambda_k}[(x_n - \mu_k)^T \Lambda_k (x_n - \mu_k)]
    • E_{\mu_k, \Lambda_k}[(x_n - \mu_k)^T \Lambda_k (x_n - \mu_k)] = D\beta_k^{-1} + \nu_k(x_n - m_k)^TW_k(x_n - m_k)
    • E[\ln|\Lambda_k|] = \sum_{i=1}^D \psi(\frac{\nu_k + 1 - i}{2}) + D \ln 2 + \ln|W_k|
    • E[\ln \pi_k] = \psi(\alpha_k) - \psi({\hat \alpha})
    • {\hat \alpha} = \sum_k \alpha_k
  5. Repeat steps 2-4 until convergence

Implementation

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from scipy.special import digamma
import matplotlib.cm as cm
plt.style.use("ggplot")

# Dimensionality of each observation
D = 2
# Total number of observations
N = 2000


# Ground-truth mixture used to synthesize the data: mean, covariance and
# sample count of each of the three clusters (30% / 50% / 20% of N).
mu1 = [0, 1]
sigma1 = 0.2 * np.eye(D)
N1 = int(N*0.3)
mu2 = [-1, -1]
sigma2 = 0.1 * np.eye(D)
N2 = int(N*0.5)
mu3 = [1, -1]
sigma3 = 0.1 * np.eye(D)
# Take the remainder rather than int(N*0.2): truncation could otherwise make
# N1 + N2 + N3 != N, and later arrays are allocated with exactly N rows.
N3 = N - N1 - N2

# Stack the three samples into a single (N, D) array.
data = np.concatenate([np.random.multivariate_normal(mu1, sigma1, N1),
                       np.random.multivariate_normal(mu2, sigma2, N2),
                       np.random.multivariate_normal(mu3, sigma3, N3)
                       ])

# Scatter plot of the raw, unlabeled data.
plt.figure(figsize=(5, 5))
plt.xlim(-2.1, 2.1)
plt.ylim(-2.1, 2.1)
plt.scatter(data[:, 0], data[:, 1], s=10)
plt.show()

w8dO9h9H66IZwAAAABJRU5ErkJggg==.png

# Initial number of classes (more than the 3 clusters the data was drawn from)
K = 6
# Initial component means (one row per class) and covariances
mu = np.array([[0., -0.5],[0., 0.5], [1., 0.5], [-1., -0.5], [-1, -1.5], [1, 1.5]])
# Use D rather than a literal 2 so the covariance shape follows the data dimension.
S = np.array([0.1 * np.eye(D) for _ in range(K)])

# Prior distribution parameters
# alpha_0: Dirichlet prior on the mixing weights
alpha_0 = 1e-3
# beta_0, m_0: Gaussian prior on each component mean
beta_0 = 1e-3
m_0 = np.zeros((K, D))
# nu_0, W_0: Wishart prior on each component precision matrix
nu_0 = 1
W_0 = np.eye(D)

# Buffers filled in on every iteration of the main loop
# W_k: posterior Wishart scale matrix of each component
W_k = np.zeros((K, D, D))
# E_mu_lam: expected quadratic form E[(x_n - mu_k)^T Lambda_k (x_n - mu_k)]
E_mu_lam = np.zeros((N, K))
 
def multi_gauss(x, y, mu, sigma):
    """Return the density of the Gaussian N(mu, sigma) at the 2-D point (x, y).

    x, y  -- coordinates of the point to evaluate
    mu    -- mean vector of the distribution
    sigma -- covariance matrix of the distribution
    """
    point = np.array([x, y])
    return stats.multivariate_normal.pdf(point, mean=mu, cov=sigma)
 
#1: Initialize the responsibilities r_nk
# Weight each initial Gaussian by a uniform mixing weight, evaluate it at every
# data point, then normalize each row so a point's responsibilities sum to 1.
pi = np.ones(K) / K
g = np.zeros((N, K))
for k in range(K):
    # One vectorized pdf call over all N points instead of constructing a
    # distribution object per point.
    g[:, k] = pi[k] * stats.multivariate_normal(mu[k], S[k]).pdf(data)
# The row sums do not depend on k, so compute them once.
r = g / np.c_[g.sum(1)]

# Plot the initial state: density contours of each initial Gaussian plus the
# data colored by the class with the largest responsibility.
X, Y = np.meshgrid(np.linspace(-2.1, 2.1), np.linspace(-2.1, 2.1))
# Grid points stacked on the last axis so pdf() can evaluate the whole grid at once.
grid = np.dstack((X, Y))
cmap_colors = [cm.spring, cm.summer, cm.autumn, cm.winter, cm.Reds_r, cm.Dark2]
colors = ["pink", "green", "orange", "blue", "red", "black"]
plt.figure(figsize=(5, 5))
for k in range(K):
    Z = stats.multivariate_normal(mu[k], S[k]).pdf(grid)
    plt.contour(X, Y, Z, cmap=cmap_colors[k], alpha=0.5)


# Build a real list of color names: on Python 3 `map` returns a lazy iterator,
# which matplotlib cannot use as a color sequence.
point_colors = [colors[j] for j in r.argmax(1)]
plt.scatter(data[:, 0], data[:, 1], c=point_colors, alpha=0.3, s=10)
plt.xlim(-2.1, 2.1)
plt.ylim(-2.1, 2.1)
init_title = "iter: 0"
plt.title(init_title)
# Uncomment to save this frame to disk:
#plt.savefig("data/" + init_title + ".png ")
plt.show()

n9OnT9Pc3KzKdIpEGtT0IVbdmfJhbduhUGhT+oCIjkz3A8h8XxDVoLYXa+NDT09P2nyQK1AkEokEDY4MJRKJZDOQwVAikUiQwVAikUgAGQwlEokEkMFQIpFIABkMJRKJBJDBUCKRSAD4HwgVxm0OO01kAAAAAElFTkSuQmCC.png

# Loop invariants, computed once: the plotting grid and the inverse of the
# prior Wishart scale matrix.
X, Y = np.meshgrid(np.linspace(-2.1, 2.1), np.linspace(-2.1, 2.1))
grid = np.dstack((X, Y))
W_0_inv = np.linalg.inv(W_0)

# Run a fixed number of variational updates (steps 2-4), plotting each one.
for i in range(20):
    #2: Calculate three statistics (N_k, weighted means, weighted covariances)
    N_k = r.sum(0)
    mu = r.T.dot(data) / np.c_[N_k]
    for k in range(K):
        centered = data - mu[k]
        S[k] = (np.c_[r[:, k]] * centered).T.dot(centered) / N_k[k]

    #3: Mstep -- update the parameters of q(pi) and q(mu_k, Lambda_k)
    alpha = alpha_0 + N_k
    beta = beta_0 + N_k
    m_k = (beta_0 * m_0 + np.c_[N_k] * mu) / np.c_[beta]
    for k in range(K):
        shift = mu[k] - m_0[k]
        tmp1 = beta_0 * N_k[k] * np.outer(shift, shift) / (beta_0 + N_k[k])
        # W_k^{-1} = W_0^{-1} + N_k S_k + tmp1
        W_k[k] = np.linalg.inv(W_0_inv + N_k[k] * S[k] + tmp1)
    nu_k = nu_0 + N_k

    #4: Estep -- recompute the responsibilities r_nk
    # E[ln|Lambda_k|] = sum_{i=1}^{D} digamma((nu_k + 1 - i) / 2) + D ln 2 + ln|W_k|
    # ln|W_k| is the log-determinant (not a matrix norm); slogdet returns
    # (sign, log|det|) for the whole stack of K matrices.
    ln_det_W = np.linalg.slogdet(W_k)[1]
    E_ln_lam = digamma((np.c_[nu_k] - np.arange(D)) / 2.).sum(1) + D * np.log(2) + ln_det_W
    E_ln_pi = digamma(alpha) - digamma(alpha.sum())
    for k in range(K):
        centered = data - m_k[k]
        # Quadratic form evaluated row by row, so no N x N intermediate is built.
        E_mu_lam[:, k] = D / beta[k] + nu_k[k] * (centered.dot(W_k[k]) * centered).sum(1)
    ln_ro = E_ln_pi + E_ln_lam / 2. - D * np.log(2 * np.pi) / 2. - E_mu_lam / 2.
    # Subtract each row's maximum before exponentiating: the shift cancels in
    # the normalization below but prevents every entry underflowing to zero.
    ro = np.exp(ln_ro - np.c_[ln_ro.max(1)])
    r = ro / np.c_[ro.sum(1)]
    # Floor the responsibilities so N_k never reaches zero (it is a divisor in step 2).
    r[r < 1e-10] = 1e-10

    #Create gif diagram
    plt.figure(figsize=(5, 5))
    # exp(E[ln pi_k]) is used as a rough indicator of each component's weight.
    pi = np.exp(E_ln_pi)
    for k in range(K):
        # Only draw (and only evaluate) components that still carry visible weight.
        if pi[k] > 0.01:
            Z = stats.multivariate_normal(mu[k], S[k]).pdf(grid)
            plt.contour(X, Y, Z, cmap=cmap_colors[k], alpha=0.5)
    # A real list of color names: on Python 3 `map` returns a lazy iterator,
    # which matplotlib cannot use as a color sequence.
    point_colors = [colors[j] for j in r.argmax(1)]
    plt.scatter(data[:, 0], data[:, 1], c=point_colors, s=10, alpha=0.3)
    plt.xlim(-2.1, 2.1)
    plt.ylim(-2.1, 2.1)
    title = "iter: {}".format(i+1)
    plt.title(title)
    # Uncomment to save this frame to disk:
    #plt.savefig("data/" + title + ".png ")
    plt.show()

Recommended Posts

Variational Bayesian estimation of mixed Gaussian distribution
EM of mixed Gaussian distribution
Concept of Bayesian reasoning (2) ... Bayesian estimation and probability distribution
PRML Chapter 9 Mixed Gaussian Distribution Python Implementation
Estimating mixed Gaussian distribution by EM algorithm
PRML Chapter 10 Variational Gaussian Distribution Python Implementation
EM algorithm calculation for mixed Gaussian distribution problem
[Python] Implementation of clustering using a mixed Gaussian model
Python: Diagram of 2D data distribution (kernel density estimation)
Variational Bayesian inference implementation of topic model in python
Verification of normal distribution