LSTM network
As you already know, plain RNN cells are not used that much in practice. A more frequently used alternative that handles long sequences much better is the Long Short-Term Memory cell, or LSTM. In this exercise, you will build an LSTM network yourself!
The most important implementation difference from the RNN network you built previously is that LSTMs have two hidden states rather than one. This means you will need to initialize the additional hidden state and pass it to the LSTM cell.
torch and torch.nn have already been imported for you, so start coding!
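Before you start, here is a minimal, illustrative sketch of the interface difference (the batch size and sequence length below are made-up toy values): unlike `nn.RNN`, `nn.LSTM` takes and returns its two states bundled in a tuple.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=1, hidden_size=32, num_layers=2, batch_first=True)

x = torch.randn(8, 5, 1)    # toy input: (batch, seq_len, features)
h0 = torch.zeros(2, 8, 32)  # short-term state: (num_layers, batch, hidden_size)
c0 = torch.zeros(2, 8, 32)  # long-term (cell) state: same shape as h0

# Both states travel together as a tuple, and a tuple comes back
out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape)  # torch.Size([8, 5, 32])
```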
This exercise is part of the course
Intermediate Deep Learning with PyTorch
Exercise instructions
- In the `.__init__()` method, define an LSTM layer and assign it to `self.lstm`.
- In the `forward()` method, initialize the first long-term memory hidden state `c0` with zeros.
- In the `forward()` method, pass all three inputs to the `LSTM` layer: the current time step's inputs, and a tuple containing the two hidden states.
Interactive hands-on exercise
Try to solve this exercise by completing the sample code.
```python
class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Define lstm layer
        ____ = ____(
            input_size=1,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        h0 = torch.zeros(2, x.size(0), 32)
        # Initialize long-term memory
        c0 = ____
        # Pass all inputs to lstm layer
        out, _ = ____
        out = self.fc(out[:, -1, :])
        return out
```
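If you want to check your answer, one possible completion of the blanks is sketched below; it simply follows the three instructions above. The dummy forward pass at the end is a hypothetical usage example, added only to show the expected output shape.

```python
class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Define lstm layer
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        h0 = torch.zeros(2, x.size(0), 32)
        # Initialize long-term memory with the same shape as h0
        c0 = torch.zeros(2, x.size(0), 32)
        # Pass the inputs and a tuple of both hidden states to the lstm layer
        out, _ = self.lstm(x, (h0, c0))
        # Use only the last time step's output for the prediction
        out = self.fc(out[:, -1, :])
        return out

# Hypothetical usage: a dummy batch of 4 sequences of length 10
net = Net(input_size=1)
print(net(torch.zeros(4, 10, 1)).shape)  # torch.Size([4, 1])
```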