Training und Testen des RNN-Modells mit Attention
Bei PyBooks hatte das Team zuvor ein RNN-Modell zur Wortvorhersage ohne Attention-Mechanismus gebaut. Dieses erste Modell, rnn_model, wurde bereits trainiert und die Instanz ist vorab geladen. Deine Aufgabe ist es nun, das neue RNNWithAttentionModel zu trainieren und seine Vorhersagen mit denen des früheren rnn_model zu vergleichen.
Folgendes ist bereits für dich vorab geladen:
inputs: Liste von Eingabesequenzen als Tensorentargets: Tensor mit Zielwörtern für jede Eingabesequenzoptimizer: Adam-Optimierer-Funktioncriterion: CrossEntropyLoss-Funktionpad_sequences: Funktion zum Auffüllen von Eingabesequenzen für Batchesattention_model: definierte Modellklasse aus der vorherigen Übungrnn_model: trainiertes RNN-Modell vom Team bei PyBooks
Diese Übung ist Teil des Kurses
Deep Learning für Text mit PyTorch
Anleitung zur Übung
- Setze das RNN-Modell vor dem Testen mit den Testdaten in den Evaluierungsmodus.
- Erhalte die RNN-Ausgabe, indem du die passende Eingabe an das RNN-Modell übergibst.
- Extrahiere aus der RNN-Ausgabe das Wort mit dem höchsten Vorhersagewert.
- Mache dasselbe für das Attention-Modell und extrahiere aus der Attention-Ausgabe das Wort mit dem höchsten Vorhersagewert.
Interaktive Übung
Vervollständige den Beispielcode, um diese Übung erfolgreich abzuschließen.
for epoch in range(epochs):
attention_model.train()
optimizer.zero_grad()
padded_inputs = pad_sequences(inputs)
outputs = attention_model(padded_inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
for input_seq, target in zip(input_data, target_data):
input_test = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)
# Set the RNN model to evaluation mode
rnn_model.____()
# Get the RNN output by passing the appropriate input
rnn_output = ____(____)
# Extract the word with the highest prediction score
rnn_prediction = ix_to_word[torch.____(____).item()]
attention_model.eval()
attention_output = attention_model(input_test)
# Extract the word with the highest prediction score
attention_prediction = ix_to_word[torch.____(____).item()]
print(f"\nInput: {' '.join([ix_to_word[ix] for ix in input_seq])}")
print(f"Target: {ix_to_word[target]}")
print(f"RNN prediction: {rnn_prediction}")
print(f"RNN with Attention prediction: {attention_prediction}")