This time, to deepen my understanding of LSTMs, I will write an LSTM from scratch in TensorFlow.
The block diagram of an LSTM with a forget gate looks like this; you can see that it consists of four small networks (**output_gate**, **input_gate**, **forget_gate**, and **z**).
**z** wants to increase its weights W so that an input worth remembering is not forgotten, but if W is increased, information that does not need to be remembered is also written, and in the end the information you wanted to keep gets overwritten. This is called an **input weight conflict**. To avoid it, the **input_gate** blocks irrelevant information and prevents it from being written into the memory cell C. The **forget_gate** erases the contents of the memory cell C when necessary; time-series data can change all at once when certain conditions are met, and the information memorized up to that point then needs to be reset.
Similarly, on the output side the **output_gate** avoids **output weight conflicts**: instead of the entire contents of the memory cell C being read out, the unnecessary parts are masked off.
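Putting the pieces together, a standard forget-gate LSTM update can be written as follows (my own summary of the textbook formulation; in the code below, the tanh usually applied to the cell in the last line is omitted, as noted later):

$$
\begin{aligned}
z_t &= \tanh(W_z [x_t, h_{t-1}] + b_z) \\
i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\
f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot z_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$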
Since the weight self.W and the bias self.B have the same shape for all four networks, we declare them together.
self.W = tf.Variable(tf.zeros([input_size + hidden_size, hidden_size * 4]))
self.B = tf.Variable(tf.zeros([hidden_size * 4]))
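As a quick aside (a throwaway check of my own, using the sizes that appear later in this post), a single matmul against this combined weight produces all four gate pre-activations at once, which tf.split then cuts apart:

import tensorflow as tf

# Sketch: input_size=8 and hidden_size=50 are the values used later in this post
input_size, hidden_size, batch = 8, 50, 3
W = tf.zeros([input_size + hidden_size, hidden_size * 4])   # shape [58, 200]
xh = tf.zeros([batch, input_size + hidden_size])            # concatenated [x, h]
z, i, f, o = tf.split(tf.matmul(xh, W), 4, axis=1)
print(z.shape)  # (3, 50) -- each of the four networks gets its own hidden_size-wide slice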
Here is the forward-propagation code. For convenience in later processing, h and c are kept stacked in a single tensor, so we first unstack them. We then compute the weighted linear sum for all four networks at once and split the result into four.
def forward(self, prev_state, x):
    # Restore h and c from the stacked state
    h, c = tf.unstack(prev_state)
    # Compute the weighted linear sum of the four networks at once
    inputs = tf.concat([x, h], axis=1)
    inputs = tf.matmul(inputs, self.W) + self.B
    z, i, f, o = tf.split(inputs, 4, axis=1)
Pass the signals of the three gates through a sigmoid.
    # Pass each gate's signal through a sigmoid
    input_gate = tf.sigmoid(i)
    forget_gate = tf.sigmoid(f)
    output_gate = tf.sigmoid(o)
Update the memory cell from the gate signals and the intermediate input, then compute the hidden-layer output. It works fine even without the tanh that is usually applied to the cell before output_gate, so it is omitted here.
    # Update the memory cell and compute the hidden-layer output
    next_c = c * forget_gate + tf.nn.tanh(z) * input_gate
    next_h = next_c * output_gate
    # Stack h and c again for later processing
    return tf.stack([next_h, next_c])
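As a sanity check on the shapes flowing through this step, here is the same computation written once in plain NumPy (a throwaway sketch with made-up sizes, not part of the model code):

import numpy as np

batch, input_size, hidden_size = 3, 8, 5            # made-up sizes for the check
W = np.zeros((input_size + hidden_size, hidden_size * 4))
B = np.zeros(hidden_size * 4)
x = np.random.rand(batch, input_size)
h = np.zeros((batch, hidden_size))
c = np.zeros((batch, hidden_size))
sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))

z, i, f, o = np.split(np.concatenate([x, h], axis=1) @ W + B, 4, axis=1)
next_c = c * sigmoid(f) + np.tanh(z) * sigmoid(i)
next_h = next_c * sigmoid(o)
print(next_h.shape, next_c.shape)                    # (3, 5) (3, 5)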
Now let's use this LSTM to write code that actually performs prediction. The dataset is **digits**, small 8x8-pixel images of the numbers 0-9.
Each image is scanned row by row, eight rows in all, and the LSTM is asked to predict which digit it is from the result.
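For reference, each digits image is an 8x8 array, so scanning it row by row gives a sequence of 8 time steps with 8 features each (a quick standalone check, separate from the script that follows):

from sklearn import datasets

digits = datasets.load_digits()
print(digits.images.shape)   # (1797, 8, 8): 1797 images, 8 rows (time steps) of 8 pixels (features)

The full script is below.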
import numpy as np
import tensorflow as tf
from sklearn import datasets
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
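# Note: this script uses the TensorFlow 1.x API (tf.placeholder, sessions); it will not run as-is on TensorFlow 2.x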
class LSTM(object):
    def __init__(self, input_size, hidden_size, output_size):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        # Input layer: [batch, time, features]
        self.inputs = tf.placeholder(tf.float32, shape=[None, None, self.input_size], name='inputs')
        self.W = tf.Variable(tf.zeros([input_size + hidden_size, hidden_size * 4]))
        self.B = tf.Variable(tf.zeros([hidden_size * 4]))
        # Output layer
        self.Wv = tf.Variable(tf.truncated_normal([hidden_size, output_size], mean=0, stddev=0.01))
        self.bv = tf.Variable(tf.truncated_normal([output_size], mean=0, stddev=0.01))
        # Initial state: multiplying the first time step by a zero matrix is just a trick to get
        # a [batch_size, hidden_size] tensor of zeros; h and c are stacked into a single tensor
        self.init_hidden = tf.matmul(self.inputs[:, 0, :], tf.zeros([input_size, hidden_size]))
        self.init_hidden = tf.stack([self.init_hidden, self.init_hidden])
        self.input_fn = self._get_batch_input(self.inputs)
    def forward(self, prev_state, x):
        # Restore h and c from the stacked state
        h, c = tf.unstack(prev_state)
        # Compute the weighted linear sum of the four networks at once
        inputs = tf.concat([x, h], axis=1)
        inputs = tf.matmul(inputs, self.W) + self.B
        z, i, f, o = tf.split(inputs, 4, axis=1)
        # Pass each gate's signal through a sigmoid
        input_gate = tf.sigmoid(i)
        forget_gate = tf.sigmoid(f)
        output_gate = tf.sigmoid(o)
        # Update the memory cell and compute the hidden-layer output
        next_c = c * forget_gate + tf.nn.tanh(z) * input_gate
        next_h = next_c * output_gate
        # Stack h and c again for later processing
        return tf.stack([next_h, next_c])
    def _get_batch_input(self, inputs):
        # Reorder [batch, time, features] into the [time, batch, features] layout that
        # tf.scan expects; the double transpose is equivalent to a single perm=[1, 0, 2]
        return tf.transpose(tf.transpose(inputs, perm=[2, 0, 1]))
    def calc_all_layers(self):
        # tf.scan applies forward() to every time step, threading the stacked [h, c] state;
        # keep only the h half for the output layer
        all_hidden_states = tf.scan(self.forward, self.input_fn, initializer=self.init_hidden, name='states')
        return all_hidden_states[:, 0, :, :]
    def calc_output(self, state):
        return tf.nn.tanh(tf.matmul(state, self.Wv) + self.bv)

    def calc_outputs(self):
        all_states = self.calc_all_layers()
        all_outputs = tf.map_fn(self.calc_output, all_states)
        return all_outputs
# Load the dataset (8x8 images of digits)
digits = datasets.load_digits()
X = digits.images
Y_ = digits.target
Y = tf.keras.utils.to_categorical(Y_, 10)
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
print(Y.shape)
# Run the prediction
hidden_size = 50
input_size = 8
output_size = 10
y = tf.placeholder(tf.float32, shape=[None, output_size], name='inputs')
lstm = LSTM(input_size, hidden_size, output_size)
outputs = lstm.calc_outputs()
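# outputs holds one output per time step; keep only the last one (after all 8 rows have been read)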
last_output = outputs[-1]
output = tf.nn.softmax(last_output)
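# Cross-entropy between the one-hot labels and the softmax output, written out by hand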
loss = -tf.reduce_sum(y * tf.log(output))
train_step = tf.train.AdamOptimizer().minimize(loss)
correct_predictions = tf.equal(tf.argmax(y, 1), tf.argmax(output, 1))
acc = (tf.reduce_mean(tf.cast(correct_predictions, tf.float32)))
sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
log_loss = []
log_acc = []
log_val_acc = []
for epoch in range(100):
    start = 0
    end = 100
    # 14 mini-batches of 100 samples per epoch
    for i in range(14):
        X = X_train[start:end]
        Y = y_train[start:end]
        start = end
        end = start + 100
        sess.run(train_step, feed_dict={lstm.inputs: X, y: Y})
    log_loss.append(sess.run(loss, feed_dict={lstm.inputs: X, y: Y}))
    log_acc.append(sess.run(acc, feed_dict={lstm.inputs: X_train[:500], y: y_train[:500]}))
    log_val_acc.append(sess.run(acc, feed_dict={lstm.inputs: X_test, y: y_test}))
    print("\r[%s] loss: %s acc: %s val acc: %s" % (epoch, log_loss[-1], log_acc[-1], log_val_acc[-1]))
# Plot the accuracy curves
plt.ylim(0., 1.)
plt.plot(log_acc, label='acc')
plt.plot(log_val_acc, label = 'val_acc')
plt.legend()
plt.show()
The accuracy is reasonably good.
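As a final quick check (a sketch of my own; it assumes the session and variables from the script above are still in scope), you can run a single test image through the graph and compare the predicted digit with its label:

# Predict the class of one test image (assumes sess, output, lstm, X_test, y_test from above)
pred = sess.run(output, feed_dict={lstm.inputs: X_test[:1]})
print("predicted:", np.argmax(pred, axis=1)[0], " label:", np.argmax(y_test[0]))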