
LSTM network

As you already know, plain RNN cells are not used that much in practice. A more frequently used alternative that handles long sequences much better is the Long Short-Term Memory cell, or LSTM. In this exercise, you will build an LSTM network yourself!

The most important implementation difference from the RNN network you built previously is that LSTMs have two hidden states rather than one: the short-term memory and the long-term memory. This means you will need to initialize the additional hidden state and pass it to the LSTM cell.
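
For reference, here is a minimal sketch of how nn.LSTM consumes and returns the two hidden states. The layer sizes are illustrative, the shapes shown assume batch_first=True, and torch and torch.nn as nn are assumed to be imported, as in the exercise:

lstm = nn.LSTM(input_size=1, hidden_size=32, num_layers=2, batch_first=True)

x = torch.randn(4, 10, 1)    # (batch, seq_len, input_size)
h0 = torch.zeros(2, 4, 32)   # short-term memory: (num_layers, batch, hidden_size)
c0 = torch.zeros(2, 4, 32)   # long-term memory: same shape as h0

output, (hn, cn) = lstm(x, (h0, c0))
print(output.shape)          # torch.Size([4, 10, 32])

Note that the hidden state shapes are (num_layers, batch, hidden_size) even when batch_first=True; that flag only affects the input and output tensors.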

torch and torch.nn have already been imported for you, so start coding!

This exercise is part of the course Intermediate Deep Learning with PyTorch.

Exercise instructions

  • In the .__init__() method, define an LSTM layer and assign it to self.lstm.
  • In the forward() method, initialize the long-term memory's first hidden state c0 with zeros.
  • In the forward() method, pass all three inputs to the LSTM layer: the input sequence x, and a tuple containing the two hidden states (h0, c0).

Hands-on interactive exercise

Have a go at this exercise by completing this sample code.

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Define lstm layer
        ____ = ____(
            input_size=input_size,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        # Initialize short-term memory: (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(2, x.size(0), 32)
        # Initialize long-term memory
        c0 = ____
        # Pass all inputs to lstm layer
        out, _ = ____
        out = self.fc(out[:, -1, :])
        return out
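
If you get stuck, one possible completion of the sample code, consistent with the instructions above, is sketched below. This is an illustrative solution, not necessarily the course's official one:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Define lstm layer
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        # Initialize short-term memory: (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(2, x.size(0), 32)
        # Initialize long-term memory with the same shape
        c0 = torch.zeros(2, x.size(0), 32)
        # Pass all inputs to lstm layer: the sequence and a tuple of both hidden states
        out, _ = self.lstm(x, (h0, c0))
        # Use only the last time step's output for the final prediction
        out = self.fc(out[:, -1, :])
        return out

Passing (h0, c0) as a tuple is the key difference from the plain RNN version, where a single h0 tensor was enough.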