The simple RNN described in the previous post, when using sigmoid or hyperbolic tangent as an activation function, tends to suffer from "vanishing gradients" when trained to work with long sequences. What does it mean to have vanishing gradients?
In order to train your RNN you need to find the gradient of the error with respect to each weight, which tells you whether to increase or decrease that weight. Now let's assume that we're using $tanh$ as the activation function for computing the next state. The function that turns an input and a state into a new state is something like the following:
$s^1 = tanh(i^1 w_i + s^0 w_s)$
We're leaving out the bias term in order to keep things short. If you want to find the gradient of $s^1$ with respect to $w_i$ (for a sequence of length one), that gradient will be
$\frac{ds^1}{dw_i} = (1 - tanh^2(i^1 w_i + s^0 w_s))(i^1)$
Notice how the gradient involves a multiplication by the factor $(1 - tanh^2(\dots))$. $tanh$ gives a number between -1 and 1, so $tanh^2$ is a fraction between 0 and 1, which means $1 - tanh^2$ is also a fraction between 0 and 1. So the gradient might be a pretty small number.
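To make this concrete, here is a minimal numerical sketch in Python (with made-up scalar values for $i^1$, $s^0$, $w_i$, and $w_s$) that computes the analytic gradient above, checks it against a finite-difference estimate, and prints the size of the $(1 - tanh^2(\dots))$ factor:

```python
import numpy as np

# Made-up scalar values, just for illustration.
i1, s0 = 0.8, 0.3
w_i, w_s = 1.2, 0.9

def state(w_i, w_s):
    # s^1 = tanh(i^1 w_i + s^0 w_s)
    return np.tanh(i1 * w_i + s0 * w_s)

# Analytic gradient of s^1 with respect to w_i.
factor = 1 - np.tanh(i1 * w_i + s0 * w_s) ** 2
analytic = factor * i1

# Finite-difference check of the same gradient.
eps = 1e-6
numeric = (state(w_i + eps, w_s) - state(w_i, w_s)) / eps

print(factor)             # always a fraction between 0 and 1
print(analytic, numeric)  # the two estimates should agree closely
```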
What happens when we have a sequence of two inputs?
$s^2 = tanh(i^2 w_i + tanh(i^1 w_i + s^0 w_s) w_s)$
The gradient will be:
$\frac{ds^2}{dw_i} = (1 - tanh^2(\dots))(i^2 + w_s (1 - tanh^2(\dots)) i^1)$
Notice how the already small factor $(1 - tanh^2(\dots))$ now gets multiplied by another such factor (and by $w_s$) before we reach our first input $i^1$. Multiplying a fraction by another fraction results in a fraction that's smaller than both. This means that $i^1$ will contribute significantly less to the gradient than $i^2$. If we have a sequence of 3 inputs then the first input picks up 3 fractional multiplications, which makes its contribution even smaller. Eventually the gradient becomes negligible (it vanishes), so the RNN only learns to reduce the error with respect to recent inputs, completely forgetting inputs towards the beginning of the sequence.
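One way to see this vanishing in action is to track how much the very first input influences the final state as the sequence grows. The sketch below (scalar state, made-up weights and inputs) accumulates that gradient with the chain rule; every extra time step multiplies it by another fraction $(1 - tanh^2(\dots))$ and by $w_s$:

```python
import numpy as np

def grad_wrt_first_input(inputs, w_i, w_s, s0=0.0):
    """Gradient of the final state with respect to the first input,
    accumulated step by step with the chain rule."""
    s, grad = s0, 0.0
    for t, i in enumerate(inputs):
        a = i * w_i + s * w_s
        if t == 0:
            grad = (1 - np.tanh(a) ** 2) * w_i
        else:
            # Each extra step multiplies the gradient by another
            # fraction (1 - tanh^2) and by w_s.
            grad = (1 - np.tanh(a) ** 2) * w_s * grad
        s = np.tanh(a)
    return grad

w_i, w_s = 1.2, 0.9                      # made-up weights
for length in (1, 2, 5, 10, 20):
    inputs = np.full(length, 0.8)        # made-up constant inputs
    print(length, grad_wrt_first_input(inputs, w_i, w_s))
```

The printed gradient shrinks rapidly as the sequence gets longer, which is exactly the vanishing described above.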
One solution to this problem is to use an activation function which does not result in a fractional multiplication at each time step, such as the rectified linear unit (ReLU). $ReLU$ is a function that returns positive numbers unchanged whilst replacing negative numbers with zero:
$ReLU(x) = max(x, 0)$
The gradient of this function is either 1 (if $x$ is positive) or 0 (if $x$ is negative). In mathematical terms this is equivalent to the indicator function $1_{x>0}$ which gives 1 when the subscript condition is met and 0 otherwise. So if we replace our $tanh$ with $ReLU$ for our previous equation we'll get the following:
$s^2 = ReLU(i^2 w_i + ReLU(i^1 w_i + s^0 w_s) w_s)$
$\frac{ds^2}{dw_i} = 1_{i^2 w_i + ReLU(i^1 w_i + s^0 w_s) w_s > 0}(i^2 + w_s 1_{i^1 w_i + s^0 w_s > 0} i^1)$
You might be worried that a single failed condition in the indicator functions will cause all earlier time steps to be multiplied by 0, wiping them out completely. But keep in mind that $i$ and $s$ are not single numbers but vectors of numbers, and that $ReLU$ works on each element of a vector independently. So whilst some elements in the vectors will be negative, others will not. With large enough vectors it is unlikely that an entire vector will consist of just negative numbers, so you're likely to have elements whose path back through time involves only multiplications by 1 (never 0), which preserves at least parts of the contribution of early inputs.
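As a small illustration of this (made-up weight matrices and inputs), the sketch below runs the ReLU recurrence on vectors and prints the indicator patterns: at each step some elements get a gradient of 0, but others get a gradient of 1:

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0)

# Made-up sizes, weight matrices, and inputs, just for illustration.
size = 8
w_i = rng.normal(scale=0.5, size=(size, size))
w_s = rng.normal(scale=0.5, size=(size, size))
i1, i2 = rng.normal(size=size), rng.normal(size=size)
s0 = np.zeros(size)

a1 = i1 @ w_i + s0 @ w_s   # pre-activation of the first step
s1 = relu(a1)
a2 = i2 @ w_i + s1 @ w_s   # pre-activation of the second step
s2 = relu(a2)

# The gradient of ReLU is the indicator 1_{x > 0}, applied element-wise.
# Some elements are blocked (0) but others pass through (1), so the early
# input is not wiped out across the whole state vector.
print((a1 > 0).astype(int))
print((a2 > 0).astype(int))
```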
Although simple RNNs with ReLUs have been reported to give good results in some papers such as this one, a more complex form of RNN called the long short-term memory (LSTM), which was designed to reduce the problem of vanishing gradients, is much more popular. According to one of its authors (Jürgen Schmidhuber), its name means that it is an RNN with a short-term memory that is long.
The idea is to make the state pass through to the next state without going through an activation function. You pass the input through an activation function, but the previous state is just added to it linearly. So instead of having this state update:
$s^1 = tanh(i^1 w_i + s^0 w_s)$
the LSTM has this state update:
$s^1 = tanh(i^1 w_i) + s^0$
This changes how the gradient changes as the sequence gets longer because now we don't have nested activation functions any more. Instead this is what you end up with when you add another time step:
$s^2 = tanh(i^2 w_i) + tanh(i^1 w_i) + s^0$
This equation preserves the gradients perfectly as the sequence gets longer, since all you're doing is adding another term. But it also loses a lot of expressive power, since the order in which you present the items in the sequence no longer matters: you'll get the same state vector regardless of how you shuffle the sequence.
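Here is a quick sketch (scalar state, made-up weight and inputs) showing that the purely additive update really is order-blind:

```python
import numpy as np

def additive_final_state(inputs, w_i, s0=0.0):
    # s^t = tanh(i^t w_i) + s^{t-1}: each input only ever gets added on.
    s = s0
    for i in inputs:
        s = np.tanh(i * w_i) + s
    return s

w_i = 1.2                                   # made-up weight
seq = np.array([0.3, -0.7, 1.1, 0.2])       # made-up sequence
shuffled = seq[[2, 0, 3, 1]]                # same items, different order

print(additive_final_state(seq, w_i))
print(additive_final_state(shuffled, w_i))  # identical: the order is lost
```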
What we need is a mixture of the expressiveness of the simple RNN mentioned in the beginning and gradient preservation of this new function. The solution is to have two state vectors: one for expressiveness called the "hidden state" $h$ and one for gradient preservation called the "cell state" $c$. Here is the new equation:
$c^1 = tanh(i^1 w_i + h^0 w_h) + c^0$
$h^1 = tanh(c^1)$
Notice the following things:
- The order of the inputs matters now because each input is combined with a different hidden state vector.
- The hidden state vector is derived from the cell state vector by just passing the cell state through a $tanh$.
- Even though the hidden state is derived from the cell state, $h^0$ and $c^0$ are separate constants and $h^0$ is not equal to $tanh(c^0)$.
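Here is a minimal sketch of this two-state recurrence (scalar state, made-up weights and initial values). Note that, unlike the purely additive update above, shuffling the inputs now changes the final state:

```python
import numpy as np

def cell_hidden_final_state(inputs, w_i, w_h, h0=0.2, c0=-0.1):
    h, c = h0, c0
    for i_t in inputs:
        c = np.tanh(i_t * w_i + h * w_h) + c   # c^t = tanh(i^t w_i + h^{t-1} w_h) + c^{t-1}
        h = np.tanh(c)                         # h^t = tanh(c^t)
    return h, c

w_i, w_h = 1.2, 0.8                            # made-up scalar weights
seq = np.array([0.3, -0.7, 1.1, 0.2])          # made-up sequence
shuffled = seq[[2, 0, 3, 1]]

print(cell_hidden_final_state(seq, w_i, w_h))
print(cell_hidden_final_state(shuffled, w_i, w_h))  # different: order matters again
```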
Let's derive the equation for the final state of a length 2 sequence step by step:
$c^2 = tanh(i^2 w_i + h^1 w_h) + c^1$
$h^2 = tanh(c^2)$
$c^2 = tanh(i^2 w_i + tanh(c^1) w_h) + tanh(i^1 w_i + h^0 w_h) + c^0$
$c^2 = tanh(i^2 w_i + tanh(tanh(i^1 w_i + h^0 w_h) + c^0) w_h) + tanh(i^1 w_i + h^0 w_h) + c^0$
You can see that you end up with a separate term for each input, where that term's input sits just inside a single $tanh$ and no deeper. Each term also contains all the previous inputs, but at much lower levels of influence, since an input is wrapped in two additional nested $tanh$ functions for every step further back in the sequence it appears. This lets the order of the inputs make a difference whilst keeping each term focused on one particular input.
The creators of the LSTM didn't stop there, however. They also introduced gating functions. A gating function is basically a single-layer sub neural net that uses a sigmoid to output a fraction between 0 and 1. This number is then multiplied by another number: if the gate is 1 it lets the second number pass unchanged, if it is 0 it blocks the second number completely, and values in between scale it down. Originally there were two gates: one for the input term $tanh(i^1 w_i + h^0 w_h)$ and one for the new hidden state $tanh(c^1)$. These gates control whether to accept the new input (the input gate) and whether to allow previous information to be represented by the hidden state (the output gate). Later, another paper introduced a forget gate that regulates the cell state as well. All of these gates are controlled by sub neural nets that take their decision based on the current input and the previous hidden state.
Here is what the final set of LSTM equations looks like:
$g_f = sig(i^{t} w_{g_{f}i} + h^{t-1} w_{g_{f}h} + b_{g_f})$
$g_i = sig(i^{t} w_{g_{i}i} + h^{t-1} w_{g_{i}h} + b_{g_i})$
$g_o = sig(i^{t} w_{g_{o}i} + h^{t-1} w_{g_{o}h} + b_{g_o})$
$c^{t} = g_i \odot tanh(i^{t} w_{ci} + h^{t-1} w_{ch} + b_{c}) + g_f \odot c^{t-1}$
$h^{t} = g_o \odot tanh(c^{t})$
where $g_f$, $g_i$, and $g_o$ are the forget gate, input gate, and output gate respectively. $w_{g_{f}i}$ is the weight applied to the input in the calculation of $g_f$ (and similarly for the other weights). The $b$ terms are the biases of the sub neural nets (we didn't use biases in the previous equations to keep things short). $\odot$ is the element-wise product of two vectors and $sig$ stands for the sigmoid function.
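To tie the equations together, here is a minimal numpy sketch of a single LSTM step. The weight and bias names mirror the equations above; the sizes and randomly initialised parameter values are made up purely for illustration, so this is a sketch rather than a trained or optimised implementation:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(i_t, h_prev, c_prev, p):
    """One LSTM step, following the equations above (element-wise gates)."""
    g_f = sigmoid(i_t @ p["w_gfi"] + h_prev @ p["w_gfh"] + p["b_gf"])  # forget gate
    g_i = sigmoid(i_t @ p["w_gii"] + h_prev @ p["w_gih"] + p["b_gi"])  # input gate
    g_o = sigmoid(i_t @ p["w_goi"] + h_prev @ p["w_goh"] + p["b_go"])  # output gate
    c_t = g_i * np.tanh(i_t @ p["w_ci"] + h_prev @ p["w_ch"] + p["b_c"]) + g_f * c_prev
    h_t = g_o * np.tanh(c_t)
    return h_t, c_t

# Made-up sizes and randomly initialised parameters, just for illustration.
rng = np.random.default_rng(0)
n_in, n_hid = 4, 3
shapes = {
    "w_gfi": (n_in, n_hid), "w_gfh": (n_hid, n_hid), "b_gf": (n_hid,),
    "w_gii": (n_in, n_hid), "w_gih": (n_hid, n_hid), "b_gi": (n_hid,),
    "w_goi": (n_in, n_hid), "w_goh": (n_hid, n_hid), "b_go": (n_hid,),
    "w_ci":  (n_in, n_hid), "w_ch":  (n_hid, n_hid), "b_c":  (n_hid,),
}
params = {name: rng.normal(scale=0.1, size=shape) for name, shape in shapes.items()}

h, c = np.zeros(n_hid), np.zeros(n_hid)
for i_t in rng.normal(size=(5, n_in)):   # a made-up length-5 input sequence
    h, c = lstm_step(i_t, h, c, params)
print(h)
```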
Here is a diagram illustrating the different components of the LSTM with gating function application (element-wise multiplication) being represented by triangles:
Earlier, you might have been sceptical about using two different state vectors to combine the expressivity of the simple RNN with the gradient preservation of the addition-based RNN described above. Can't we just use the cell state in the $tanh$ as well? In fact a later paper described another kind of RNN called the gated recurrent unit (GRU) which does just that. Instead of using one state to differentiate between time steps and another to carry previous states into later calculations, only one state is used. Gating functions are still used, except that the input gate is replaced with 1 minus the forget gate. Basically it finds a compromise between forgetting the previous state and focussing everything on the new input, or ignoring the new input and preserving the previous state. Finally, instead of an output gate there is now a reset gate, which controls how much the previous state contributes to the input term. Here is what the GRU equations look like:
$g_f = sig(i^{t} w_{g_{f}i} + h^{t-1} w_{g_{f}h} + b_{g_f})$
$g_i = 1 - g_f$
$g_r = sig(i^{t} w_{g_{r}i} + h^{t-1} w_{g_{r}h} + b_{g_r})$
$h^{t} = g_i \odot tanh(i^{t} w_{hi} + (g_r \odot h^{t-1}) w_{hh} + b_{h}) + g_f \odot h^{t-1}$
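Here is the corresponding sketch of a single GRU step, again with made-up sizes and randomly initialised parameters:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(i_t, h_prev, p):
    """One GRU step, following the equations above."""
    g_f = sigmoid(i_t @ p["w_gfi"] + h_prev @ p["w_gfh"] + p["b_gf"])  # forget gate
    g_i = 1.0 - g_f                                                    # input gate
    g_r = sigmoid(i_t @ p["w_gri"] + h_prev @ p["w_grh"] + p["b_gr"])  # reset gate
    h_t = g_i * np.tanh(i_t @ p["w_hi"] + (g_r * h_prev) @ p["w_hh"] + p["b_h"]) \
          + g_f * h_prev
    return h_t

# Made-up sizes and randomly initialised parameters, just for illustration.
rng = np.random.default_rng(0)
n_in, n_hid = 4, 3
shapes = {
    "w_gfi": (n_in, n_hid), "w_gfh": (n_hid, n_hid), "b_gf": (n_hid,),
    "w_gri": (n_in, n_hid), "w_grh": (n_hid, n_hid), "b_gr": (n_hid,),
    "w_hi":  (n_in, n_hid), "w_hh":  (n_hid, n_hid), "b_h":  (n_hid,),
}
params = {name: rng.normal(scale=0.1, size=shape) for name, shape in shapes.items()}

h = np.zeros(n_hid)
for i_t in rng.normal(size=(5, n_in)):   # a made-up length-5 input sequence
    h = gru_step(i_t, h, params)
print(h)
```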
And here is the diagram: