
LSTM network

As you already know, plain RNN cells are not used that much in practice. A more frequently used alternative that handles long sequences much better is the Long Short-Term Memory cell, or LSTM. In this exercise, you will build an LSTM network yourself!

The most important implementation difference from the RNN network you built previously is that LSTMs maintain two hidden states rather than one: a short-term memory and a long-term (cell) memory. This means you will need to initialize this additional hidden state and pass it to the LSTM cell, as sketched below.
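To make the shapes concrete before you start, here is a minimal sketch of how nn.LSTM consumes and returns that extra state. The batch size, sequence length, and tensor values are illustrative only, not part of the exercise:

lstm = nn.LSTM(input_size=1, hidden_size=32, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 1)    # illustrative batch: (batch, seq_len, features)
h0 = torch.zeros(2, 4, 32)   # short-term memory: (num_layers, batch, hidden_size)
c0 = torch.zeros(2, 4, 32)   # long-term (cell) memory: same shape as h0
out, (hn, cn) = lstm(x, (h0, c0))  # the state goes in, and comes out, as a tuple
print(out.shape)             # torch.Size([4, 10, 32])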

torch and torch.nn have already been imported for you, so start coding!

This exercise is part of the course

Intermediate Deep Learning with PyTorch


Exercise instructions

  • In the .__init__() method, define an LSTM layer and assign it to self.lstm.
  • In the forward() method, initialize the first long-term memory hidden state c0 with zeros.
  • In the forward() method, pass all three inputs to the LSTM layer: the input sequence x, plus a tuple containing the two hidden states (h0, c0).

Hands-on interactive exercise

Try this exercise by completing the sample code below.

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Define lstm layer
        ____ = ____(
            input_size=1,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
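        # Initialize short-term memory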
        h0 = torch.zeros(2, x.size(0), 32)
        # Initialize long-term memory
        c0 = ____
        # Pass all inputs to lstm layer
        out, _ = ____
        out = self.fc(out[:, -1, :])
        return out
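
For reference, here is one way to fill in the blanks that satisfies the instructions above; check it against your own attempt rather than copying it directly:

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Define lstm layer
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        # Initialize short-term memory
        h0 = torch.zeros(2, x.size(0), 32)
        # Initialize long-term memory with the same shape as h0
        c0 = torch.zeros(2, x.size(0), 32)
        # Pass the inputs and the tuple of both hidden states to the lstm layer
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out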