Comment apprendre le SVM structuré de ChainCRF avec PyStruct

Les documents et les échantillons sur le Web d'origine ne sont pas conviviaux, j'ai donc essayé d'utiliser des données faciles à comprendre.

Première préparation


import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from pystruct.inference import inference_dispatch

Le contenu est la suppression du bruit des données de séries chronologiques comme dans Implémentation de HMM avec PyStruct. Pour l'apprentissage, une série chronologique avec du bruit ajouté à une série temporelle fixe est utilisée. (Mis à part le fait qu'il est corrigé pour que vous n'ayez pas à en déduire)

Création de données d'entraînement


n_samples = 500

d = np.array([12, 12, 11, 11, 10,  9,  8,  8,  7,  6,  6,  6,  7,  8,  8,  8,  6,
        5,  4,  3,  3,  3,  2,  1,  0,  1,  3,  4,  5,  6,  8,  8,  9,  9,
       10, 11, 12, 13, 14, 14, 14, 15, 15, 15, 15])
n_nodes = d.shape[0]
n_states = np.unique(d).shape[0]
n_features = n_states + 1 # add bias

y = np.repeat(d[np.newaxis,:], n_samples, axis=0)

data = y + (np.random.rand(n_samples, n_nodes)-0.5)*5

# negative sign for maximization !
X = np.array( [ [ [ -abs(i-j)**0.1 for j in range(n_states)]  for i in dd ] for dd in data] )

# add constant features for bias
X = np.array( [np.hstack((X[i], 0.1*np.ones((X[i].shape[0],1)))) for i in range(X.shape[0])] )

Data X a 500 nombres, 45 longueurs de séries chronologiques, 16 états / classes et 17 caractéristiques (biais SVM).

Vérifier la taille


X.shape, y.shape
===
((500, 45, 17), (500, 45))

Divisez l'apprentissage et les tests comme d'habitude


from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)

Vérifier les données d'entraînement


fig, axes = plt.subplots(3,3, figsize=(20,6))
c=0
for ax in axes.ravel():
    ax.plot(data[c], label='data')
    ax.plot(y_train[c], label='true')
    ax.set_xticks(())
    ax.set_yticks(())
    c += 1
plt.legend(bbox_to_anchor=(1.1, 1.0), loc=2, borderaxespad=0.)

Unknown1.png

Comparaison des données d'entraînement X (caractéristiques à chaque instant) et y (vraies séries temporelles fixes) pour confirmation.

Vérification


plt.matshow(np.flipud(X_train[0,:,:-1].T)) # remove bias
plt.colorbar()
plt.yticks(())
#plt.show()

plt.plot(15-y_train[0]) # flipud
plt.show()

Unknown2.png

Maintenant, préparez l'apprenant. Apprenez avec FrancWolfe SSVM selon l'explication de ChainCRF de PyStruct.

Préparation de l'apprenant


from pystruct.models import ChainCRF
from pystruct.learners import FrankWolfeSSVM
model = ChainCRF()
ssvm = FrankWolfeSSVM(model=model, C=.1, max_iter=10)

Apprendre!


%%time
ssvm.fit(X_train, y_train)
====
CPU times: user 1.25 s, sys: 17.4 ms, total: 1.27 s
Wall time: 1.3 s

FrankWolfeSSVM(C=0.1, batch_mode=False, check_dual_every=10,
        do_averaging=True, line_search=True, logger=None, max_iter=10,
        model=ChainCRF(n_states: 16, inference_method: max-product),
        n_jobs=1, random_state=None, sample_method='perm',
        show_loss_every=0, tol=0.001, verbose=0)

Alors quel est le score prédit?


ssvm.score(X_test, y_test)
==========
0.56377777777777771

Vérifiez les prédictions pour le test


X_test_predict = np.array(ssvm.predict(X_test))

fig, axes = plt.subplots(3,3, figsize=(20,6))
shf = np.arange(X_test.shape[0])
np.random.shuffle(shf)
c=0
for ax in axes.ravel():
    ax.plot(data[shf[c]], label='data')
    ax.plot(X_test_predict[shf[c]], label='predict')
    ax.plot(y_test[shf[c]], label='true')
    ax.set_xticks(())
    ax.set_yticks(())
    c += 1

plt.legend(bbox_to_anchor=(1.1, 1.0), loc=2, borderaxespad=0.)

Unknown3.png

Vérifiez le w appris


ssvm.w.shape # = n_features * n_states + n_states**2
========
(528,)

Poids par paire w


plt.matshow(ssvm.w[n_features * n_states:].reshape(n_states, n_states))
plt.title("Transition parameters of the chain CRF.")
plt.xticks(np.arange(n_states))
plt.yticks(np.arange(n_states))
plt.colorbar()
plt.show()

Unknown4.png

poids unaire w


plt.matshow(ssvm.w[:n_features * n_states].reshape(n_states,n_features))
plt.title("Unary parameters of the chain CRF.")
plt.yticks(np.arange(n_states))
plt.xticks(np.arange(n_features))
plt.ylabel('states') 
plt.xlabel('features')
plt.colorbar()
plt.show()

Unknown5.png

Recommended Posts

Comment apprendre le SVM structuré de ChainCRF avec PyStruct
Comment déduire l'estimation MAP de HMM avec PyStruct
[Hugo] Résumé de la façon d'ajouter des pages au site créé avec Learn
Comment spécifier des attributs avec Mock of Python
Comment implémenter "named_scope" de RubyOnRails avec Django
Comment entraîner Kaldi avec JUST Corpus
Comment déduire une estimation MAP de HMM avec OpenGM
[Comment!] Apprenez et jouez à Super Mario avec Tensorflow !!
Résumé de la façon de partager l'état avec plusieurs fonctions
Comment mettre à jour avec SQLAlchemy?
Comment modifier avec SQLAlchemy?
Comment séparer les chaînes avec ','
Comment supprimer avec SQLAlchemy?
Comment activer la lecture / écriture de net.Conn avec Golang pour annuler avec le contexte
Comment annuler RT avec Tweepy
Comment extraire des fonctionnalités de données de séries chronologiques avec les bases de PySpark
Python: comment utiliser async avec
Résumé de l'utilisation de pandas.DataFrame.loc
Comment obtenir l'ID de Type2Tag NXP NTAG213 avec nfcpy
Pour utiliser virtualenv avec PowerShell
Résumé de l'utilisation de pyenv-virtualenv
J'ai essayé de résumer brièvement la procédure de démarrage du développement de Django
Comment démarrer avec Scrapy
Comment gérer l'erreur DistributionNotFound
Comment démarrer avec Django
Comment surveiller l'état d'exécution de sqlldr avec la commande pv
Comment augmenter les données avec PyTorch
Explique comment utiliser TensorFlow 2.X avec l'implémentation de VGG16 / ResNet50
Node.js: Comment tuer les descendants d'un processus démarré par child_process.fork ()
Comment calculer la date avec python
Résumé de l'utilisation de csvkit
Comment INNER JOIN avec SQL Alchemy
Comment installer Anaconda avec pyenv
[EC2] Comment faire une capture d'écran de votre smartphone avec du sélénium
Comment couper la partie inférieure droite de l'image avec Python OpenCV
[Introduction à Python] Comment trier efficacement le contenu d'une liste avec le tri par liste
[Reconnaissance d'image] Comment lire le résultat de l'annotation automatique avec VoTT
Comment gérer les caractères déformés dans json de Django REST Framework
Résumé de la création d'un environnement LAMP + Wordpress avec Sakura VPS
Comment effectuer un traitement arithmétique avec le modèle Django
[Blender] Comment définir shape_key avec un script
[Python] Résumé de l'utilisation des pandas
Comment titrer plusieurs figures avec matplotlib
Comment accélérer la belle instanciation de soupe
Comment obtenir l'identifiant du parent avec sqlalchemy
Apprenez à coloriser les images monochromes avec Chainer
Comment se débarrasser des longues inclusions
Comment configurer SVM à l'aide d'Optuna
Comment installer DLIB avec 2020 / CUDA activé
Comment utiliser ManyToManyField avec l'administrateur de Django
Comment utiliser OpenVPN avec Ubuntu 18.04.3 LTS
Comment utiliser Cmder avec PyCharm (Windows)
Comment empêcher les mises à jour de paquets avec apt
Comment utiliser BigQuery en Python
Comment utiliser Ass / Alembic avec HtoA
Comment gérer les erreurs de compatibilité d'énumération