RNN-model met attention trainen en testen
Bij PyBooks heeft het team eerder een RNN-model voor woordvoorspelling gebouwd zonder het attention-mechanisme. Dit eerste model, rnn_model genoemd, is al getraind en het exemplaar is vooraf ingeladen. Jouw taak is nu om het nieuwe RNNWithAttentionModel te trainen en de voorspellingen te vergelijken met die van het eerdere rnn_model.
Het volgende is voor je ingeladen:
inputs: lijst met invoersequenties als tensorstargets: tensor met doelwoorden voor elke invoersequentieoptimizer: Adam-optimizerfunctiecriterion: CrossEntropyLoss-functiepad_sequences: functie om invoersequenties te padden voor batchingattention_model: gedefinieerde modelklasse uit de vorige oefeningrnn_model: getraind RNN-model van het team bij PyBooks
Deze oefening maakt deel uit van de cursus
Deep Learning voor tekst met PyTorch
Oefeninstructies
- Zet het RNN-model in evaluatiemodus voordat je het met de testgegevens test.
- Haal de RNN-output op door de juiste invoer aan het RNN-model door te geven.
- Haal uit de RNN-output het woord met de hoogste voorspellingsscore.
- Doe voor het attention-model hetzelfde: haal uit de attention-output het woord met de hoogste voorspellingsscore.
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
for epoch in range(epochs):
attention_model.train()
optimizer.zero_grad()
padded_inputs = pad_sequences(inputs)
outputs = attention_model(padded_inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
for input_seq, target in zip(input_data, target_data):
input_test = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)
# Set the RNN model to evaluation mode
rnn_model.____()
# Get the RNN output by passing the appropriate input
rnn_output = ____(____)
# Extract the word with the highest prediction score
rnn_prediction = ix_to_word[torch.____(____).item()]
attention_model.eval()
attention_output = attention_model(input_test)
# Extract the word with the highest prediction score
attention_prediction = ix_to_word[torch.____(____).item()]
print(f"\nInput: {' '.join([ix_to_word[ix] for ix in input_seq])}")
print(f"Target: {ix_to_word[target]}")
print(f"RNN prediction: {rnn_prediction}")
print(f"RNN with Attention prediction: {attention_prediction}")