The paper is available here: *Machine Learning on Sequential Data Using a Recurrent Weighted Average*. In the figure above, (c) is a schematic diagram of the RWA model ((a) is a plain LSTM, (b) is an LSTM with attention).
RWA is a variant of the Recurrent Neural Network (*RNN*) family for handling sequential data. According to the paper, compared with the LSTM, which is the most commonly used RNN implementation, RWA offers:
- Better accuracy
- Faster convergence
- Fewer parameters
In other words, **nothing but upsides**. Surprised by the strength of the claim and the simplicity of the architecture, I wondered whether it could really surpass the LSTM, which is by now practically the de facto standard. So I took a third-party [Keras implementation of RWA](https://gist.github.com/shamatar/55b804cf62b8ee0fa23efdb3ea5a4701), rewrote it slightly, and tried to reproduce some of the experiments in the paper.
You can think of RWA as a generalization of attention, redefined recursively so that it is built into the structure of the RNN itself. In other words, attention (within an RNN) is a special case of RWA.
Let's take a closer look.
All RNNs, not just the LSTM, are networks for processing sequential data.
Sequential data is easiest to handle under a Markov assumption (i.e., the current state is determined only by the current input and the past state), so an RNN is modeled recursively: it takes the "current input" and the "previous state" and outputs the "current state".
Written as a formula, this is $h_t = Recurrent(x_t, h_{t-1})$, where $x_t$ is the current input and $h_{t-1}$ is the previous state.[^1]
However, as the figure at the top shows, the attention model is not defined recursively and cannot be expressed in the form of the $Recurrent$ function.[^2] RWA regards attention as a weighted moving average of past states, and transforms the expression equivalently so that it reduces to the $Recurrent$ form.
Specifically, RWA takes a weighted moving average of past states, shown as equation (2) below, where $f$ is a suitable activation function, $z$ is the term that transforms the state, and $a$ controls how much weight each past state gets in the average ($a$ corresponds to attention). Left in this form, the expression cannot be called recursive, because the summation $\Sigma$ involves adding up the past states from $1$ to $t$. Somehow, we want to redefine equation (2) in terms of the previous state only.
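Roughly, equation (2) has the following form (a sketch of the weighted average described above, with $z(x_i, h_{i-1})$ the state-transforming term and $a(x_i, h_{i-1})$ the attention term; the exact symbols may differ slightly from the paper's figure):

```math
h_t = f\left(\frac{\sum_{i=1}^{t} z(x_i, h_{i-1}) \odot a(x_i, h_{i-1})}{\sum_{i=1}^{t} a(x_i, h_{i-1})}\right)
```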
Now, split the argument of $f$ in equation (2) into a numerator $n$ and a denominator $d$. Both are cumulative sums, so they can be rewritten as follows. At this point, $n$ and $d$ have been transformed into a form that depends only on the previous time step, and that is really the whole trick.
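Written out, the recursive form is roughly as follows (abbreviating $z_t = z(x_t, h_{t-1})$ and $a_t = a(x_t, h_{t-1})$):

```math
n_t = n_{t-1} + z_t \odot a_t, \qquad
d_t = d_{t-1} + a_t, \qquad
h_t = f\left(\frac{n_t}{d_t}\right)
```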
Finally, $z$ is defined slightly differently from an ordinary RNN: it is split into a term $u$ (an embedding) that looks only at the input and a term $g$ that also looks at the state. With that, the RWA equations are complete.[^3]
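Following the paper, a concrete parameterization looks roughly like this, where $[x_t; h_{t-1}]$ denotes concatenating the input with the previous state (the use of $\tanh$ and the exponential attention weight are the paper's choices; treat the exact form here as an assumption):

```math
u_t = W_u x_t + b_u, \qquad
g_t = W_g [x_t; h_{t-1}] + b_g, \qquad
z_t = u_t \odot \tanh(g_t), \qquad
a_t = \exp\left(W_a [x_t; h_{t-1}]\right)
```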
The resulting RWA has much the same structure as the simplest RNNs. Because RWA starts from a formulation that refers to all past states, it can be expected to update its state with reference to any point in the past, even though it has none of the LSTM's internal machinery such as forget gates or output gates.
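Putting the pieces together, a single RWA step can be sketched in NumPy as follows. This is only an illustration of the recursion above, not the gist's actual implementation; the weight names and shapes are assumptions.

```python
import numpy as np

def rwa_step(x_t, h_prev, n_prev, d_prev, W_u, b_u, W_g, b_g, W_a):
    """One RWA step: an illustrative sketch, not the reference implementation."""
    xh = np.concatenate([x_t, h_prev])  # input and previous state, concatenated
    u = W_u @ x_t + b_u                 # embedding term u: looks only at the input
    g = W_g @ xh + b_g                  # term g: also looks at the previous state
    z = u * np.tanh(g)                  # state-transforming term z
    a = np.exp(W_a @ xh)                # (unnormalized) attention weight
    n = n_prev + z * a                  # running numerator of the weighted average
    d = d_prev + a                      # running denominator
    h = np.tanh(n / d)                  # new state: activation f applied to the average
    return h, n, d
```

As footnote [^3] notes, real implementations rewrite $n$ and $d$ in an equivalent, numerically safer form (Appendix B of the paper); the naive exponentials above can overflow on long sequences.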
I ran my experiments with code that modifies the Keras implementation of RWA published by a third party (https://gist.github.com/shamatar/55b804cf62b8ee0fa23efdb3ea5a4701) so that the `return_sequences` parameter works.
The modified code and the experiment / visualization scripts are available here.
(`return_sequences` is a parameter of Keras Recurrent layers that lets you output the full history of states rather than only the final state. Without it, the states cannot be visualized afterwards.)
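As a usage sketch (the layer class name `RWA`, its module, the unit count, and the output head below are assumptions, not the gist's exact interface):

```python
from keras.models import Sequential
from keras.layers import Dense, Lambda

# `RWA` is assumed to be the recurrent layer class from the modified gist.
from rwa import RWA

model = Sequential()
# return_sequences=True makes the layer emit its state at every time step,
# which is what makes the state heatmaps shown later possible.
model.add(RWA(250, return_sequences=True, input_shape=(None, 1)))
model.add(Lambda(lambda s: s[:, -1, :]))   # keep only the final state for the prediction
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
```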
Of the experiments described in the paper, I ran the two that were easiest to implement.
Classifying by Sequence Length
This task asks whether the length of a given sequence exceeds a certain threshold. A vector whose length is drawn at random between 0 and 1000 is prepared; if its length exceeds 500 the label is 1, otherwise 0. The elements of the vector are filled with values drawn from a normal distribution (note: the element values are irrelevant to this task; only the length of the vector matters). The objective function is binary_crossentropy. The paper used a mini-batch size of 100, but since it was a hassle to pack sequences of different lengths into the same batch, I used a batch size of 1 for this task (this takes a long time: the results below took roughly 12 hours on a GPU).
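A minimal sketch of how such data can be generated (the minimum length of 1 and the exact array shapes are assumptions):

```python
import numpy as np

def sample_sequence():
    """One (sequence, label) pair for the length-classification task above."""
    length = np.random.randint(1, 1001)            # random length up to 1000 (minimum of 1 assumed)
    x = np.random.randn(length, 1)                 # element values are irrelevant noise
    y = np.array([[1.0 if length > 500 else 0.0]]) # label depends only on the length
    return x[np.newaxis], y                        # batch size 1, as used in this experiment
```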
The experimental results are as follows (Vertical axis: Accuracy (higher is better), Horizontal axis: Number of epochs)
- Results from the paper
- Results of this experiment
Due to time constraints the LSTM was still partway through training, but the outcome matches the paper in that RWA converged overwhelmingly fast (since the batch sizes differ, I cannot say exactly how many samples it took to converge).
While working with the data I was curious what the RWA's state looked like, so I plotted it as well. The vertical axis is time and the horizontal axis is the state dimension (250).
The figure above is an example with sequence length 1000 (i.e., the correct prediction is 1), and in this case the prediction was correct. Looking at the plot of the state, the state appears to change around the point where the sequence passes length 500, and overall it forms a gradation along the time direction. Apparently the model learned the task correctly. I tried various sequences of different lengths: accuracy dropped sharply when the length was around 500, while it was 100% for sequences that were much shorter or much longer (the figure above is itself an example of an extremely long sequence).
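For reference, a heatmap like this can be drawn with a few lines of matplotlib (the `states` array of shape (timesteps, 250) is assumed to be one sample taken from the RWA layer's `return_sequences` output):

```python
import numpy as np
import matplotlib.pyplot as plt

states = np.random.randn(1000, 250)  # placeholder; use the RWA layer's per-step states here

plt.imshow(states, aspect='auto', interpolation='nearest')
plt.xlabel('state dimension (250)')
plt.ylabel('time step')
plt.colorbar()
plt.show()
```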
Adding Problem
The task is: given a vector of some length, add two randomly chosen values. The model receives two vectors of length n: one holds real numbers, the other is all zeros except for 1s at two positions. The model has to learn to output the sum of the two real numbers at the positions marked with 1. The objective function is MSE. This task was run with a mini-batch size of 100, as in the paper, and training took less than an hour on a GPU.
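A minimal sketch of a batch generator for this task (stacking the two vectors as input channels is an assumption about the data layout):

```python
import numpy as np

def make_adding_batch(batch_size=100, length=100):
    """One mini-batch for the adding problem described above."""
    values = np.random.rand(batch_size, length)                 # real-valued sequence
    markers = np.zeros((batch_size, length))                    # indicator sequence
    targets = np.zeros((batch_size, 1))
    for i in range(batch_size):
        idx = np.random.choice(length, size=2, replace=False)   # two marked positions
        markers[i, idx] = 1.0
        targets[i, 0] = values[i, idx].sum()                    # sum of the two marked values
    x = np.stack([values, markers], axis=-1)                    # shape (batch, length, 2)
    return x, targets
```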
The experimental results are as follows (Vertical axis: MSE (lower is better), horizontal axis: number of epochs)
- Results from the paper
- Results of this experiment (length 100)
- Results of this experiment (length 1000)
Note that the scale of the horizontal axis differs (I trained with 1 epoch = 100 batches, so multiplying the horizontal axis by 100 gives the same scale as the original paper). For RWA, I was able to reproduce the results of the paper. The LSTM matched the paper for length 100, but did not learn well for length 1000. Judging from how the LSTM converges in the original paper, perhaps its accuracy would start to improve after roughly another 100 epochs of training on the length-1000 sequences.
The paper also claims that RWA can solve problems (at least the ones they tried) without fiddling with hyperparameters or initialization settings, so the fact that only RWA reproduced the paper's results in one shot may, if anything, be the more desirable outcome for a reproduction attempt.
The RWA states are shown below (each graph corresponds to one sample). The vertical axis is time (100 or 1000) and the horizontal axis is the state dimension (250). The "where" values written above each figure indicate the positions of the two flagged elements in that sample.
You can see that some dimensions of the state change sharply at the moments the elements to be added (i.e., the "where" positions) appear. It does indeed seem that the model can learn to detect events contained in sequential data.
Personally, I find RWA much simpler and easier to understand than the LSTM, and a nice, direct realization of an intuitive idea. The paper only makes the most basic comparison against a plain LSTM, so it remains unclear how RWA fares against an LSTM with attention, or what happens when layers are stacked into a multi-layer (stacked) model, as is often done with LSTMs. (That said, the situations where an attention model can be applied are limited, and since RWA is something like a generalization of attention, a direct comparison may not even be possible...) If more research follows, RWA might well replace the LSTM as the de facto standard.
[^1]: If you express this with a linear model that has no state, you get an AR model; if you express it with a hidden Markov model whose state equation is written explicitly, you get a state space model.
[^2]: Because it depends on all past states, not just the previous one.
[^3]: In the implementation, n and d are transformed into an equivalent form to reduce numerical error. See Appendix B of the paper for details.