L'algorithme EM est présenté au chapitre 9 de PRML. L'algorithme EM lui-même est une technique qui peut être utilisée à divers endroits, et j'ai moi-même combiné l'algorithme EM avec un classificateur pour effectuer une classification plus résistante au bruit. Cependant, comme je n'avais jamais groupé avec la distribution gaussienne mixte, qui est l'exemple d'application le plus célèbre de l'application de l'algorithme EM, j'ai implémenté ** l'estimation la plus probable de la distribution gaussienne mixte avec l'algorithme EM ** cette fois.
Par exemple, une distribution gaussienne normale des points de données tels que les points bleus dans la figure ci-dessus.
p({\bf x}) = \mathcal{N}({\bf x}|{\bf\mu},{\bf\Sigma})
Lors du montage avec, C'est pourquoi la distribution gaussienne à pic unique n'est pas un bon modèle dans ce cas. En se concentrant sur le fait que les points de données sont divisés en trois groupes, dans ce cas, une distribution gaussienne mixte utilisant trois distributions gaussiennes.
p({\bf x}) = \sum_{k=1}^3 \pi_k\mathcal{N}({\bf x}|{\bf\mu}_k,{\bf\Sigma}_k)
Est considéré comme un modèle approprié. Cependant, soit $ \ sum_k \ pi_k = 1 $. $ {\ Bf \ mu} \ _1 $ est le haut, $ {\ bf \ mu} \ _2 $ est le bas à droite et $ {\ bf \ mu} \ _3 $ est la moyenne des blocs de points de données en bas à gauche. Je vais. Par conséquent, ** ajuster chaque bloc avec une distribution gaussienne est bien, mais il est évident pour nous, humains, quels points de données appartiennent à quel bloc, mais les machines ne le savent pas.
Quel point de données appartient à lequel des K blocs, avec les coordonnées $ \ {{\ bf x} \ _n \} \ _ {n = 1} ^ N $ de N points de données comme variables observées Variable latente de codage 1 sur k $ \ {{\ bf z} \ _n \} \ _ {n = 1} ^ N $ et paramètre $ \ {{\ bf \ pi} \ _k , {\ bf \ mu} \ _ k, {\ bf \ Sigma} \ _ k \} \ _ {k = 1} ^ K $ sont estimés en même temps. Lorsqu'il existe des variables latentes comme celle-ci, il est courant d'utiliser l'algorithme EM.
La procédure d'estimation la plus probable de la distribution gaussienne mixte par l'algorithme EM (équations PRML (9.23) à (9.27)) est résumée dans la section 9.2.2 de PRML, elle est donc omise ici.
Utilisez Numpy et matplotlib comme d'habitude.
import matplotlib.pyplot as plt
import numpy as np
class GaussianMixture(object):
def __init__(self, n_component):
#Nombre de distributions gaussiennes
self.n_component = n_component
#Estimation la plus probable à l'aide de l'algorithme EM
def fit(self, X, iter_max=10):
#Dimension des données
self.ndim = np.size(X, 1)
#Initialisation du coefficient de mélange
self.weights = np.ones(self.n_component) / self.n_component
#Initialisation moyenne
self.means = np.random.uniform(X.min(), X.max(), (self.ndim, self.n_component))
#Initialisation de la matrice de covariance
self.covs = np.repeat(10 * np.eye(self.ndim), self.n_component).reshape(self.ndim, self.ndim, self.n_component)
#Répétez les étapes E et M
for i in xrange(iter_max):
params = np.hstack((self.weights.ravel(), self.means.ravel(), self.covs.ravel()))
#Étape E, calculez le taux de charge
resps = self.expectation(X)
#Étape M, mise à jour des paramètres
self.maximization(X, resps)
#Vérifiez si les paramètres ont convergé
if np.allclose(params, np.hstack((self.weights.ravel(), self.means.ravel(), self.covs.ravel()))):
break
else:
print("parameters may not have converged")
#Fonction gaussienne
def gauss(self, X):
precisions = np.linalg.inv(self.covs.T).T
diffs = X[:, :, None] - self.means
assert diffs.shape == (len(X), self.ndim, self.n_component)
exponents = np.sum(np.einsum('nik,ijk->njk', diffs, precisions) * diffs, axis=1)
assert exponents.shape == (len(X), self.n_component)
return np.exp(-0.5 * exponents) / np.sqrt(np.linalg.det(self.covs.T).T * (2 * np.pi) ** self.ndim)
#Étape E
def expectation(self, X):
#Formule PRML(9.23)
resps = self.weights * self.gauss(X)
resps /= resps.sum(axis=-1, keepdims=True)
return resps
#Étape M
def maximization(self, X, resps):
#Formule PRML(9.27)
Nk = np.sum(resps, axis=0)
#Formule PRML(9.26)
self.weights = Nk / len(X)
#Formule PRML(9.24)
self.means = X.T.dot(resps) / Nk
diffs = X[:, :, None] - self.means
#Formule PRML(9.25)
self.covs = np.einsum('nik,njk->ijk', diffs, diffs * np.expand_dims(resps, 1)) / Nk
#Distribution de probabilité p(x)Calculer
def predict_proba(self, X):
#Formule PRML(9.7)
gauss = self.weights * self.gauss(X)
return np.sum(gauss, axis=-1)
#Clustering
def classify(self, X):
joint_prob = self.weights * self.gauss(X)
return np.argmax(joint_prob, axis=1)
gaussian_mixture.py
import matplotlib.pyplot as plt
import numpy as np
class GaussianMixture(object):
def __init__(self, n_component):
self.n_component = n_component
def fit(self, X, iter_max=10):
self.ndim = np.size(X, 1)
self.weights = np.ones(self.n_component) / self.n_component
self.means = np.random.uniform(X.min(), X.max(), (self.ndim, self.n_component))
self.covs = np.repeat(10 * np.eye(self.ndim), self.n_component).reshape(self.ndim, self.ndim, self.n_component)
for i in xrange(iter_max):
params = np.hstack((self.weights.ravel(), self.means.ravel(), self.covs.ravel()))
resps = self.expectation(X)
self.maximization(X, resps)
if np.allclose(params, np.hstack((self.weights.ravel(), self.means.ravel(), self.covs.ravel()))):
break
else:
print("parameters may not have converged")
def gauss(self, X):
precisions = np.linalg.inv(self.covs.T).T
diffs = X[:, :, None] - self.means
assert diffs.shape == (len(X), self.ndim, self.n_component)
exponents = np.sum(np.einsum('nik,ijk->njk', diffs, precisions) * diffs, axis=1)
assert exponents.shape == (len(X), self.n_component)
return np.exp(-0.5 * exponents) / np.sqrt(np.linalg.det(self.covs.T).T * (2 * np.pi) ** self.ndim)
def expectation(self, X):
resps = self.weights * self.gauss(X)
resps /= resps.sum(axis=-1, keepdims=True)
return resps
def maximization(self, X, resps):
Nk = np.sum(resps, axis=0)
self.weights = Nk / len(X)
self.means = X.T.dot(resps) / Nk
diffs = X[:, :, None] - self.means
self.covs = np.einsum('nik,njk->ijk', diffs, diffs * np.expand_dims(resps, 1)) / Nk
def predict_proba(self, X):
gauss = self.weights * self.gauss(X)
return np.sum(gauss, axis=-1)
def classify(self, X):
joint_prob = self.weights * self.gauss(X)
return np.argmax(joint_prob, axis=1)
def create_toy_data():
x1 = np.random.normal(size=(100, 2))
x1 += np.array([-5, -5])
x2 = np.random.normal(size=(100, 2))
x2 += np.array([5, -5])
x3 = np.random.normal(size=(100, 2))
x3 += np.array([0, 5])
return np.vstack((x1, x2, x3))
def main():
X = create_toy_data()
model = GaussianMixture(3)
model.fit(X, iter_max=100)
labels = model.classify(X)
x_test, y_test = np.meshgrid(np.linspace(-10, 10, 100), np.linspace(-10, 10, 100))
X_test = np.array([x_test, y_test]).reshape(2, -1).transpose()
probs = model.predict_proba(X_test)
Probs = probs.reshape(100, 100)
colors = ["red", "blue", "green"]
plt.scatter(X[:, 0], X[:, 1], c=[colors[int(label)] for label in labels])
plt.contour(x_test, y_test, Probs)
plt.xlim(-10, 10)
plt.ylim(-10, 10)
plt.show()
if __name__ == '__main__':
main()
Les paramètres de la distribution gaussienne mixte sont très probablement estimés en utilisant les points comme données d'apprentissage, et la distribution de probabilité est illustrée par des courbes de niveau. De plus, la couleur des points indique à quel cluster il appartient. C'est le résultat du succès, mais ** échoue parfois **. Comme décrit dans PRML, maximiser la fonction de vraisemblance logarithmique cette fois est un mauvais problème de réglage et peut ne pas être une bonne solution. Il existe des heuristiques qui fonctionnent autour de cela, mais cette fois, cela échoue car nous n'avons pas implémenté de solution de contournement. Cependant, cela ne devrait pas trop échouer.
Le clustering non supervisé a été effectué par ajustement avec une distribution gaussienne mixte. Le nombre de distributions gaussiennes utilisées à ce moment-là est spécifié ici. Dans le prochain chapitre 10, nous présenterons une méthode pour estimer automatiquement le nombre d'éléments d'une distribution gaussienne mixte appropriée, donc la prochaine fois, nous implémenterons les variantes de baies utilisées ici.
Recommended Posts