Article sur la migration n ° 3.
Cette fois, c'est une expérience simple. Notez le code d'écriture.
L'algorithme EM n'est pas facile à utiliser (je pense personnellement) car l'équation à résoudre change en fonction du modèle. Par conséquent, envisagez de l'obtenir en échantillonnant directement à partir du modèle. Ici, l'échantillonnage est effectué à l'aide de pymc3.
Les données à échantillonner ont été générées comme suit.
N=1000
X, y = datasets.make_blobs(n_samples=N, random_state=8)
transformation = [[0.6, -0.6], [-0.4, 0.8]]
X_aniso = np.dot(X, transformation)
df = pd.DataFrame()
df['x'] = X_aniso.T[0]
df['y'] = X_aniso.T[1]
df['c'] = y
Les données d'entraînement sont tracées comme suit. On peut voir visuellement qu'il y a trois groupes.
plt.figure(figsize=(10, 10))
sns.scatterplot(x='x', y='y', data=df, hue='c')
plt.show()
Soit les données d'observation $ x $, le cluster $ z $ et le paramètre $ \ theta_z $. On suppose que ceux-ci sont générés comme suit.
\displaystyle{
\begin{aligned}
x_i &\sim N(x|, \mu_{z_i}, I)\\
\mu_k &\sim N(\mu_k| 0, I)\\
z_i &\sim Cat(z_i|\pi)\\
\pi &\sim Dir(\pi|\alpha)
\end{aligned}
}
Lorsque cela a été écrit par programme, il est devenu comme suit. La bibliothèque a utilisé pymc3. En regardant les données, il était clair que le nombre de grappes était de 3, donc ici nous avons échantillonné avec 3 grappes.
k=3
data_dim = len(df.T) -1
data_size = len(data)
with pm.Model() as model:
pi = pm.Dirichlet('p', a=np.ones(k), shape=k)
pi_min_potential = pm.Potential('pi_min_potential', tt.switch(tt.min(pi) < .1, -np.inf, 0))
z = pm.Categorical('z', p=pi, shape=data_size)
mus = pm.MvNormal('mus', mu=np.zeros(data_dim), cov=np.eye(data_dim), shape=(k, data_dim))
y = pm.MvNormal('obs', mu=mus[z], cov=np.eye(data_dim), observed=df.drop(columns='c').to_numpy())
tr = pm.sample(10*data_size, random_seed=0, chains=1)
De plus, parmi les résultats d'échantillonnage obtenus, la grappe utilise la valeur la plus fréquente.
df['pred'] = scipy.stats.mode(tr['z'], axis=0).mode[0]
Tracez les résultats en utilisant le code ci-dessous. Vous pouvez voir qu'ils sont bien regroupés.
plt.figure(figsize=(10, 10))
sns.scatterplot(x='x', y='y', data=df, hue='pred')
plt.show()
--Bayes Statistical Modeling with Python: Data Analysis Practice Guide with PyMC