Cette fois, afin d'approfondir la compréhension de LSTM, j'écrirai LSTM avec scratch dans TensorFlow.
Le schéma de principe de LSTM avec Forget_gate est le suivant, et vous pouvez voir qu'il se compose de 4 petits réseaux (** output_gate, input_gate, forget_gate, z **).
** Z ** veut augmenter le poids W afin qu'il ne soit pas oublié s'il y a une entrée dont vous voulez vous souvenir, mais si vous augmentez W, vous vous souviendrez également des informations dont vous n'avez pas besoin de vous souvenir, donc à la fin les informations que vous vouliez retenir Sera écrasé. Ceci est appelé ** conflit de poids d'entrée **. Pour éviter cela, ** input_gate ** bloque les informations non pertinentes et empêche leur écriture dans la cellule de mémoire C. ** forget_gate ** efface les informations de la cellule de mémoire C si nécessaire. En effet, la série de données chronologiques peut changer à la fois lorsque certaines conditions sont remplies et il est nécessaire de réinitialiser les informations qui ont été mémorisées à ce moment-là.
** output_gate ** ne lit pas tout le contenu de la cellule mémoire C, mais efface ceux qui ne sont pas nécessaires, comme dans le cas de l'entrée, pour éviter ** le conflit de poids de sortie **.
Les quatre pondérations de réseau self.W et biais self.B ont la même forme, nous les déclarons donc ensemble.
self.W = tf.Variable(tf.zeros([input_size + hidden_size, hidden_size *4 ]))
self.B = tf.Variable(tf.zeros([hidden_size * 4 ]))
Code de propagation avant. Cette fois, pour faciliter le post-traitement, h et c sont empilés, alors restaurez-les d'abord. Il calcule ensuite la somme linéaire pondérée des quatre réseaux ensemble et divise le résultat en quatre.
def forward(self, prev_state, x):
# h,Restaurer c
h, c = tf.unstack(prev_state)
#Calculer la somme linéaire pondérée de quatre réseaux ensemble
inputs = tf.concat([x, h], axis=1)
inputs = tf.matmul(inputs, self.W) + self.B
z, i, f, o = tf.split(inputs, 4, axis=1)
Passez le sigmoïde à travers les signaux des trois portes.
#Passer sigmoïde à travers le signal de chaque porte
input_gate = tf.sigmoid(i)
forget_gate = tf.sigmoid(f)
output_gate = tf.sigmoid(o)
Mettez à jour la cellule de mémoire en fonction des entrées de la porte et de la couche intermédiaire pour calculer la sortie de la couche intermédiaire. Il n'y a pas de problème même s'il n'y a pas de tanh avant output_gate, donc il est omis.
#Mise à jour de la cellule mémoire, calcul de sortie intermédiaire
next_c = c * forget_gate + tf.nn.tanh(z) * input_gate
next_h = next_c * output_gate
#Pile due au post-traitement
return tf.stack([next_h, next_c])
Utilisons maintenant ce LSTM pour écrire le code qui effectue réellement la prédiction. L'ensemble de données utilise une image de ** dagits ** avec des ** nombres 0-9 ** (plus petits 8 * 8 pixels).
Sur la base du résultat de l'analyse d'une ligne de données par ligne huit fois, laissez LSTM prédire le nombre.
import numpy as np
import tensorflow as tf
from sklearn import datasets
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
class LSTM(object):
def __init__(self, input_size, hidden_size, output_size):
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
#Couche d'entrée
self.inputs = tf.placeholder(tf.float32, shape=[None, None, self.input_size], name='inputs')
self.W = tf.Variable(tf.zeros([input_size + hidden_size, hidden_size *4 ]))
self.B = tf.Variable(tf.zeros([hidden_size * 4 ]))
#Couche de sortie
self.Wv = tf.Variable(tf.truncated_normal([hidden_size, output_size], mean=0, stddev=0.01))
self.bv = tf.Variable(tf.truncated_normal([output_size], mean=0, stddev=0.01))
self.init_hidden = tf.matmul(self.inputs[:,0,:], tf.zeros([input_size, hidden_size]))
self.init_hidden = tf.stack([self.init_hidden, self.init_hidden])
self.input_fn = self._get_batch_input(self.inputs)
def forward(self, prev_state, x):
# h,Restaurer c
h, c = tf.unstack(prev_state)
#Calculer la somme linéaire pondérée de quatre réseaux ensemble
inputs = tf.concat([x, h], axis=1)
inputs = tf.matmul(inputs, self.W) + self.B
z, i, f, o = tf.split(inputs, 4, axis=1)
#Passer sigmoïde à travers le signal de chaque porte
input_gate = tf.sigmoid(i)
forget_gate = tf.sigmoid(f)
output_gate = tf.sigmoid(o)
#Mise à jour de la cellule mémoire, calcul de sortie intermédiaire
next_c = c * forget_gate + tf.nn.tanh(z) * input_gate
next_h = next_c * output_gate
#Pile due au post-traitement
return tf.stack([next_h, next_c])
def _get_batch_input(self, inputs):
return tf.transpose(tf.transpose(inputs, perm=[2, 0, 1]))
def calc_all_layers(self):
all_hidden_states = tf.scan(self.forward, self.input_fn, initializer=self.init_hidden, name='states')
return all_hidden_states[:, 0, :, :]
def calc_output(self, state):
return tf.nn.tanh(tf.matmul(state, self.Wv) + self.bv)
def calc_outputs(self):
all_states = self.calc_all_layers()
all_outputs = tf.map_fn(self.calc_output, all_states)
return all_outputs
#Charger le jeu de données( 8*8 image of a digit)
digits = datasets.load_digits()
X = digits.images
Y_= digits.target
Y=tf.keras.utils.to_categorical(Y_, 10)
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
print(Y.shape)
#Exécution prédictive
hidden_size = 50
input_size = 8
output_size = 10
y = tf.placeholder(tf.float32, shape=[None, output_size], name='inputs')
lstm = LSTM(input_size, hidden_size, output_size)
outputs = lstm.calc_outputs()
last_output = outputs[-1]
output = tf.nn.softmax(last_output)
loss = -tf.reduce_sum(y * tf.log(output))
train_step = tf.train.AdamOptimizer().minimize(loss)
correct_predictions = tf.equal(tf.argmax(y, 1), tf.argmax(output, 1))
acc = (tf.reduce_mean(tf.cast(correct_predictions, tf.float32)))
sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
log_loss = []
log_acc = []
log_val_acc = []
for epoch in range(100):
start=0
end=100
for i in range(14):
X=X_train[start:end]
Y=y_train[start:end]
start=end
end=start+100
sess.run(train_step,feed_dict={lstm.inputs:X, y:Y})
log_loss.append(sess.run(loss,feed_dict={lstm.inputs:X, y:Y}))
log_acc.append(sess.run(acc,feed_dict={lstm.inputs:X_train[:500], y:y_train[:500]}))
log_val_acc.append(sess.run(acc,feed_dict={lstm.inputs:X_test, y:y_test}))
print("\r[%s] loss: %s acc: %s val acc: %s"%(epoch, log_loss[-1], log_acc[-1], log_val_acc[-1])),
#Graphisme acc
plt.ylim(0., 1.)
plt.plot(log_acc, label='acc')
plt.plot(log_val_acc, label = 'val_acc')
plt.legend()
plt.show()
C'est une précision relativement bonne.
Recommended Posts