LSTM network
As you already know, plain RNN cells are rarely used in practice. A more common alternative that handles long sequences much better is the Long Short-Term Memory cell, or LSTM. In this exercise, you will build an LSTM network yourself!
The most important implementation difference from the RNN network you built previously is that LSTMs have two hidden states rather than one. This means you will need to initialize this additional hidden state and pass it to the LSTM cell, as the sketch below illustrates.
torch and torch.nn have already been imported for you, so start coding!
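Before you start, it may help to see both hidden states in action. The following is a minimal sketch, not part of the exercise: the batch and sequence sizes are illustrative assumptions, and the key point is that nn.LSTM expects its two states to be passed together as a tuple.

import torch           # already imported for you in the exercise itself
import torch.nn as nn

lstm = nn.LSTM(input_size=1, hidden_size=32, num_layers=2, batch_first=True)

x = torch.randn(8, 10, 1)          # (batch, seq_len, input_size)
h0 = torch.zeros(2, 8, 32)         # short-term state: (num_layers, batch, hidden_size)
c0 = torch.zeros(2, 8, 32)         # long-term (cell) state: same shape as h0
out, (hn, cn) = lstm(x, (h0, c0))  # outputs plus both final hidden states
print(out.shape)                   # torch.Size([8, 10, 32])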
This exercise is part of the course Intermediate Deep Learning with PyTorch.
Exercise instructions
- In the .__init__() method, define an LSTM layer and assign it to self.lstm.
- In the forward() method, initialize the first long-term memory hidden state c0 with zeros.
- In the forward() method, pass all three inputs to the LSTM layer: the current time step's inputs, and a tuple containing the two hidden states.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Define lstm layer
        ____ = ____(
            input_size=1,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        # Initialize short-term memory
        h0 = torch.zeros(2, x.size(0), 32)
        # Initialize long-term memory
        c0 = ____
        # Pass all inputs to lstm layer
        out, _ = ____
        out = self.fc(out[:, -1, :])
        return out
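For reference, below is one possible way to fill in the blanks, consistent with the instructions above. Try completing the exercise yourself first. The shape check at the end uses dummy sizes that are illustrative assumptions, not part of the exercise.

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Define lstm layer
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        # Initialize short-term memory
        h0 = torch.zeros(2, x.size(0), 32)
        # Initialize long-term memory
        c0 = torch.zeros(2, x.size(0), 32)
        # Pass the inputs and a tuple of both hidden states to the lstm layer
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out

net = Net(input_size=1)
preds = net(torch.randn(4, 16, 1))  # (batch, seq_len, input_size)
print(preds.shape)                  # torch.Size([4, 1])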