LSTM-Netz
Wie du bereits weißt, werden einfache RNN-Zellen in der Praxis nicht sehr häufig verwendet. Eine häufiger verwendete Alternative, die eine viel bessere Verarbeitung langer Sequenzen gewährleistet, sind Long Short-Term Memory-Zellen oder LSTMs. In dieser Aufgabe wirst du selbst ein LSTM-Netz aufbauen!
Der wichtigste Implementierungsunterschied zum RNN-Netz, das du zuvor erstellt hast, ergibt sich aus der Tatsache, dass LSTMs zwei statt einem versteckten Zustand haben. Das bedeutet, dass du diesen zusätzlichen versteckten Zustand initialisieren und an die LSTM-Zelle übergeben musst.
torch
und torch.nn
wurden bereits für dich importiert, also fang an zu programmieren!
Diese Übung ist Teil des Kurses
Deep Learning mit PyTorch für Fortgeschrittene
Anleitung zur Übung
- Definiere in der
.__init__()
-Methode eine LSTM-Schicht und weise sieself.lstm
zu. - Initialisiere in der
forward()
-Methode den ersten versteckten Zustand des Long-Term Memoryc0
mit Nullen. - Übergib in der
forward()
-Methode alle drei Inputs an dieLSTM
-Schicht: die Inputs des aktuellen Zeitschritts und ein Tupel mit den beiden versteckten Zuständen.
Interaktive Übung zum Anfassen
Probieren Sie diese Übung aus, indem Sie diesen Beispielcode ausführen.
class Net(nn.Module):
def __init__(self, input_size):
super().__init__()
# Define lstm layer
____ = ____(
input_size=1,
hidden_size=32,
num_layers=2,
batch_first=True,
)
self.fc = nn.Linear(32, 1)
def forward(self, x):
h0 = torch.zeros(2, x.size(0), 32)
# Initialize long-term memory
c0 = ____
# Pass all inputs to lstm layer
out, _ = ____
out = self.fc(out[:, -1, :])
return out