Rete LSTM
Come già sai, le celle RNN semplici non sono molto usate in pratica. Un'alternativa più comune, che gestisce molto meglio le sequenze lunghe, sono le Long Short-Term Memory, o LSTM. In questo esercizio costruirai tu una rete LSTM!
La differenza di implementazione più importante rispetto alla rete RNN che hai creato prima è che le LSTM hanno due stati nascosti invece di uno. Questo significa che dovrai inizializzare questo stato nascosto aggiuntivo e passarlo alla cella LSTM.
torch e torch.nn sono già stati importati per te, quindi puoi iniziare a scrivere il codice!
Questo esercizio fa parte del corso
Deep Learning intermedio con PyTorch
Istruzioni dell'esercizio
- Nel metodo
.__init__(), definisci un livello LSTM e assegnalo aself.lstm. - Nel metodo
forward(), inizializza con zeri il primo stato nascosto della memoria a lungo terminec0. - Nel metodo
forward(), passa tutti e tre gli input al livelloLSTM: gli input dell'istante corrente e una tupla che contiene i due stati nascosti.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
class Net(nn.Module):
def __init__(self, input_size):
super().__init__()
# Define lstm layer
____ = ____(
input_size=1,
hidden_size=32,
num_layers=2,
batch_first=True,
)
self.fc = nn.Linear(32, 1)
def forward(self, x):
h0 = torch.zeros(2, x.size(0), 32)
# Initialize long-term memory
c0 = ____
# Pass all inputs to lstm layer
out, _ = ____
out = self.fc(out[:, -1, :])
return out