
LSTM and GRU cells

1. LSTM and GRU cells

Let's discuss recurrent architectures more powerful than a plain RNN.

2. Short-term memory problem

Because RNN neurons pass the hidden state from one time step to the next, they can be said to maintain some sort of memory. That's why they are often called RNN memory cells, or just cells for short. However, this memory is very short-term: by the time a long sentence is processed, the hidden state doesn't retain much information about its beginning. Imagine trying to translate a sentence between languages: by the time we have finished reading it, we no longer remember how it started. To solve this short-term memory problem, two more powerful types of cells have been proposed: the Long Short-Term Memory or LSTM cell and the Gated Recurrent Unit or GRU cell.

3. RNN cell

Before we look at LSTM and GRU cells, let's visualize the plain RNN cell. At each time step t, it takes two inputs: the current input data x and the previous hidden state h. It multiplies these inputs by the weights, applies the activation function, and produces two outputs: the current output y and the next hidden state.
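
As a minimal sketch of this single step (not from the original lesson; the tensor sizes are illustrative assumptions), PyTorch's nn.RNNCell makes the two inputs and the shared output/hidden state explicit:

```python
import torch
import torch.nn as nn

# Illustrative sizes, assumed for this sketch
batch_size, input_size, hidden_size = 4, 8, 16

cell = nn.RNNCell(input_size, hidden_size)  # computes tanh(W_x x + W_h h + b)

x = torch.randn(batch_size, input_size)    # current input data x at time step t
h = torch.zeros(batch_size, hidden_size)   # previous hidden state h

h_next = cell(x, h)  # next hidden state; for a plain cell the output y is this same tensor
```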

4. LSTM cell

The LSTM cell has three inputs and three outputs. In addition to the input data x, there are two hidden states: h represents the short-term memory and c the long-term memory. At each time step, h and x are passed through linear layers called gate controllers, which determine what is important enough to keep in the long-term memory. The gate controllers first erase some parts of the long-term memory in the forget gate. Then, they analyze x and h and store their most important parts in the long-term memory in the input gate. This long-term memory, c, is one of the outputs of the cell. At the same time, another gate called the output gate determines what the current output y should be. The short-term memory output h is the same as y.
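
A hedged sketch of these three inputs and three outputs, using PyTorch's nn.LSTMCell (the sizes are again assumptions for illustration; the gate controllers live inside the cell):

```python
import torch
import torch.nn as nn

batch_size, input_size, hidden_size = 4, 8, 16

cell = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(batch_size, input_size)    # input data x
h = torch.zeros(batch_size, hidden_size)   # short-term memory h
c = torch.zeros(batch_size, hidden_size)   # long-term memory c

# The forget, input, and output gates are applied internally by the cell
h_next, c_next = cell(x, (h, c))  # h_next doubles as the output y
```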

5. LSTM in PyTorch

Building an LSTM network in PyTorch is very similar to the plain RNN we have already seen. In the init method, we only need to use the nn.LSTM layer instead of nn.RNN. The arguments that the layer takes as inputs are the same. In the forward method, we add the long-term hidden state c and initialize both h and c with zeros. Then, we pass h and c as a tuple to the LSTM layer. Finally, we take the last output, pass it through the linear layer, and return it just like before.
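
The slide's code is not reproduced in this transcript, so the following is a hedged reconstruction of the described network; the class name Net, the layer sizes, and the use of batch_first=True are assumptions:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, input_size=1, hidden_size=32, output_size=1):
        super().__init__()
        # Only change from the plain RNN: nn.LSTM instead of nn.RNN, same arguments
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Initialize both the short-term (h) and long-term (c) states with zeros
        h0 = torch.zeros(1, x.size(0), self.lstm.hidden_size)
        c0 = torch.zeros(1, x.size(0), self.lstm.hidden_size)
        out, _ = self.lstm(x, (h0, c0))   # pass h and c as a tuple to the LSTM layer
        out = self.fc(out[:, -1, :])      # take the last output, pass it through the linear layer
        return out
```

With these assumed sizes, calling Net() on a tensor of shape (4, 10, 1) would map a batch of 4 sequences of length 10 to 4 predictions.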

6. GRU cell

The GRU cell is a simplified version of the LSTM cell. It merges the long-term and short-term memories into a single hidden state. It also doesn't use an output gate: the entire hidden state is returned at each time step.
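
To contrast with the LSTM cell sketch above, a minimal nn.GRUCell step keeps only a single merged hidden state (sizes again assumed for illustration):

```python
import torch
import torch.nn as nn

batch_size, input_size, hidden_size = 4, 8, 16

cell = nn.GRUCell(input_size, hidden_size)

x = torch.randn(batch_size, input_size)    # input data x
h = torch.zeros(batch_size, hidden_size)   # single hidden state (no separate c)

h_next = cell(x, h)  # no output gate: the entire hidden state is returned as the output
```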

7. GRU in PyTorch

Building a GRU network in PyTorch is almost identical to the plain RNN. All we need to do is replace nn.RNN with nn.GRU when defining the layer in the init method, and then call the new GRU layer in the forward method.
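
Again as a hedged sketch (class name and sizes assumed, mirroring the LSTM reconstruction above), the change is swapping the layer for nn.GRU and dropping the c state:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, input_size=1, hidden_size=32, output_size=1):
        super().__init__()
        # Replace nn.RNN (or nn.LSTM) with nn.GRU; the arguments stay the same
        self.gru = nn.GRU(input_size, hidden_size, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.gru.hidden_size)  # single hidden state, no c
        out, _ = self.gru(x, h0)
        out = self.fc(out[:, -1, :])  # last output through the linear layer, as before
        return out
```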

8. Should I use RNN, LSTM, or GRU?

So, which type of recurrent network should we use: the plain RNN, LSTM, or GRU? There is no single answer, but consider the following. Although plain RNNs have revolutionized modeling of sequential data and are important to understand, they are not used much these days because of the short-term memory problem. Our choice will likely be between LSTM and GRU. GRU's advantage is that it's less complex than LSTM, which means less computation. Other than that, the relative performance of GRU and LSTM varies per use case, so it's often a good idea to try both and compare the results. We will learn how to evaluate these models soon.

9. Let's practice!

Let's practice!