Estimation de probabilité postérieure maximale des séries d'étiquettes avec HMM en utilisant l'interface python d'OpenGM.
Préparation
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import opengm
Les données sont les suivantes. Essayez d'intégrer et de lisser les données de séries chronologiques.
Les données
d = '17.2 19.7 21.6 21.3 22.1 20.5 16.3 18.4 21.0 16.1 17.5 18.5 18.4 18.3 16.0 21.2 18.8 24.3 23.3 20.5 16.9 22.4 20.1 24.5 24.2 22.7 19.6 23.6 23.3 24.6 25.0 24.3 22.2 22.7 19.5 20.5 17.3 17.2 22.0 20.9 21.5 22.3 24.0 22.4 20.2 15.7 20.4 16.3 17.7 14.3 18.4 16.6 13.9 15.2 14.8 15.0 11.5 13.4 13.5 17.0 15.0 17.5 12.3 11.8 14.5 12.4 12.9 15.8 13.8 11.4 6.5 5.9 7.2 5.6 4.6 7.5 8.9 6.6 3.9 5.7 7.3 6.1 6.8 3.1 2.6 7.9 5.2 2.0 4.0 3.4 5.7 8.1 4.7 5.4 5.9 3.6 2.9 5.7 2.1 1.6 2.3 2.4 1.2 4.2 4.2 2.4 5.6 2.5 3.0 6.1 4.9 7.1 5.0 7.2 5.2 5.1 10.4 8.3 6.9 6.8 7.8 4.2 8.0 3.2 7.9 5.9 9.5 6.4 9.2 11.7 11.6 15.5 16.7'
d = np.array([ float(c) for c in d.split()])
Maintenant, construisez le modèle HMM et exécutez l'inférence.
Courir!
nNodes = d.shape[0] #Nombre de nœuds.
nLabels = 20 #Nombre de classes discrètes. 20
variableSpace = np.ones(nNodes)*nLabels #Nombre d'étiquettes pour chaque nœud. Tout de même ici
gm = opengm.gm(variableSpace)
# unary
for i in range(nNodes):
u = np.array([ abs(d[i] - j) for j in range(nLabels) ]) #Terme de données. Valeur absolue de la différence avec l'étiquette
f = gm.addFunction(u)
gm.addFactor(f, i)
# pairwise
p = 10 #Coût lorsque les classes des nœuds adjacents sont différentes. (0 si identique)
pairwise = np.array((np.ma.ones((nLabels,nLabels)) - np.eye(nLabels)) * p) #Terme par paires. 0 pour le même label, p pour différents
f_pw = gm.addFunction(pairwise)
for i in range(nNodes-1):
gm.addFactor(f_pw, [i, i+1]) #Définition du bord qui relie les nœuds. Puisqu'il s'agit d'un HMM, il est unidimensionnel.
inf = opengm.inference.DynamicProgramming(gm=gm) #Algorithme d'inférence: DP suffit car il est unidimensionnel
inf.infer() #Exécution d'inférence
res = inf.arg() #Collectez les résultats
#terrain.
plt.plot(d, label="data")
plt.plot(res, label="result")
plt.legend()
Recommended Posts