Training and testing the RNN model with attention
At PyBooks, the team had previously built an RNN model for word prediction without the attention mechanism. This initial model, referred to as rnn_model, has already been trained, and its instance is preloaded. Your task now is to train the new RNNWithAttentionModel and compare its predictions with those of the earlier rnn_model.
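For reference, a model of this kind can be sketched as follows. This is a minimal illustration, not the RNNWithAttentionModel from the previous exercise. The class name RNNWithAttentionSketch, the layer sizes, and the single-linear-layer scoring are all assumptions. An attention layer scores every time step, softmax turns those scores into weights, and the weighted sum of hidden states (the context vector) feeds the output layer. A plain RNN predictor typically uses only the last hidden state instead.

```python
import torch
import torch.nn as nn

class RNNWithAttentionSketch(nn.Module):
    """Hypothetical next-word model: embedding -> RNN -> attention pooling -> linear."""

    def __init__(self, vocab_size, embed_dim=16, hidden_dim=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        # Scores each time step's hidden state with a single scalar
        self.attention = nn.Linear(hidden_dim, 1)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) tensor of word indices
        rnn_out, _ = self.rnn(self.embedding(x))      # (batch, seq_len, hidden_dim)
        scores = self.attention(rnn_out).squeeze(-1)  # (batch, seq_len)
        weights = torch.softmax(scores, dim=1)        # weights sum to 1 per sequence
        context = torch.sum(weights.unsqueeze(-1) * rnn_out, dim=1)  # (batch, hidden_dim)
        return self.fc(context)                       # (batch, vocab_size) scores

model = RNNWithAttentionSketch(vocab_size=10)
logits = model(torch.randint(0, 10, (2, 5)))
print(logits.shape)  # torch.Size([2, 10])
```

This sketch does not mask padding positions, so padded tokens in a batch also receive some attention weight. Adding a mask is one possible refinement.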
The following has been preloaded for you:
- inputs: list of input sequences as tensors
- targets: tensor containing the target word for each input sequence
- optimizer: Adam optimizer
- criterion: CrossEntropyLoss loss function
- pad_sequences: function that pads input sequences for batching
- attention_model: instance of the RNNWithAttentionModel class defined in the previous exercise
- rnn_model: trained RNN model from the team at PyBooks
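The preloaded pad_sequences helper is not shown. One plausible version is a thin wrapper around torch.nn.utils.rnn.pad_sequence. This sketch assumes that the helper right-pads with index 0 and returns a (batch, max_len) tensor. The name pad_sequences_sketch is hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_sequences_sketch(batch, padding_value=0):
    """Hypothetical stand-in for the preloaded pad_sequences helper.

    Right-pads a list of 1-D index tensors to the length of the longest one
    and stacks them into a (batch, max_len) tensor.
    """
    return pad_sequence(batch, batch_first=True, padding_value=padding_value)

inputs_demo = [torch.tensor([4, 2, 7]), torch.tensor([5, 1])]
print(pad_sequences_sketch(inputs_demo))
# tensor([[4, 2, 7],
#         [5, 1, 0]])
```

Padding with 0 assumes that index 0 is reserved for padding rather than used for a real word.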
This exercise is part of the course Deep Learning for Text with PyTorch.
Exercise instructions
- Set the RNN model to evaluation mode before testing it with the test data.
- Get the RNN output by passing the appropriate input to the RNN model.
- Extract the word with the highest prediction score from the RNN output.
- Similarly, for the attention model, extract the word with the highest prediction score from the attention output.
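Extracting the predicted word means taking the index of the highest score in the model output and mapping it back through ix_to_word. The toy vocabulary and scores below are made up for illustration. The sketch assumes that the model output has shape (1, vocab_size).

```python
import torch

# Hypothetical toy vocabulary and a fake model output of shape (1, vocab_size)
ix_to_word_demo = {0: "the", 1: "cat", 2: "sat", 3: "mat"}
output = torch.tensor([[0.1, 2.5, -0.3, 1.2]])

# Without dim, torch.argmax returns the index of the maximum over the flattened
# tensor, which matches the vocabulary index when the batch size is 1
predicted_ix = torch.argmax(output).item()
print(ix_to_word_demo[predicted_ix])  # cat

# For a batch, take the argmax along the vocabulary dimension instead
batch_output = torch.tensor([[0.1, 2.5, -0.3, 1.2],
                             [3.0, 0.0, 0.5, 0.2]])
batch_ix = torch.argmax(batch_output, dim=1)
print([ix_to_word_demo[i] for i in batch_ix.tolist()])  # ['cat', 'the']
```

Softmax preserves ordering, so applying argmax to raw scores or to softmax probabilities selects the same word.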
The sample code below trains the attention model and then compares its predictions with those of rnn_model on the test data.
```python
import torch

# Assumes epochs, inputs, targets, optimizer, criterion, pad_sequences,
# attention_model, rnn_model, input_data, target_data and ix_to_word are preloaded

# Train the attention model on the padded batch
for epoch in range(epochs):
    attention_model.train()
    optimizer.zero_grad()
    padded_inputs = pad_sequences(inputs)
    outputs = attention_model(padded_inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

# Set both models to evaluation mode before testing
rnn_model.eval()
attention_model.eval()

with torch.no_grad():  # gradients are not needed at test time
    for input_seq, target in zip(input_data, target_data):
        # Add a batch dimension: shape (1, seq_len)
        input_test = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)
        # Get the RNN output by passing the test input
        rnn_output = rnn_model(input_test)
        # Extract the word with the highest prediction score
        rnn_prediction = ix_to_word[torch.argmax(rnn_output).item()]
        attention_output = attention_model(input_test)
        # Extract the word with the highest prediction score
        attention_prediction = ix_to_word[torch.argmax(attention_output).item()]
        print(f"\nInput: {' '.join([ix_to_word[ix] for ix in input_seq])}")
        print(f"Target: {ix_to_word[target]}")
        print(f"RNN prediction: {rnn_prediction}")
        print(f"RNN with Attention prediction: {attention_prediction}")
```
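Printing one prediction at a time is easy to read but hard to summarize. A small helper can count how often each model's top prediction matches the target. The function next_word_accuracy and the toy EchoLastWord model are hypothetical. The helper assumes that each model maps a (1, seq_len) index tensor to (1, vocab_size) scores.

```python
import torch
import torch.nn as nn

def next_word_accuracy(model, input_data, target_data):
    """Hypothetical helper: fraction of sequences whose top prediction equals the target index."""
    model.eval()  # switches off train-only behaviour such as dropout
    correct = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for input_seq, target in zip(input_data, target_data):
            x = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)  # (1, seq_len)
            pred = torch.argmax(model(x), dim=1).item()
            correct += int(pred == target)
    return correct / len(target_data)

class EchoLastWord(nn.Module):
    """Hypothetical toy model that always predicts the last input word."""

    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, x):
        # One-hot scores for the final index of each sequence: (batch, vocab_size)
        return nn.functional.one_hot(x[:, -1], self.vocab_size).float()

acc = next_word_accuracy(EchoLastWord(5), [[1, 2], [3, 4]], [2, 0])
print(acc)  # 0.5
```

The exercise uses rnn_model, attention_model, input_data and target_data. With those objects, calls such as next_word_accuracy(rnn_model, input_data, target_data) would give one number per model. This assumes input_data and target_data hold plain index lists, as the test loop treats them. On a dataset this small, read any difference between the two numbers with caution.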