Creating an RNN model with attention
At PyBooks, the team has been exploring various deep learning architectures. After some research, you decide to implement an RNN with an attention mechanism to predict the next word in a sentence. You're given a dataset of sentences and a vocabulary built from them.
The following packages have been imported for you: torch, nn.
The following has been preloaded for you:
- vocab and vocab_size: the vocabulary set and its size
- word_to_ix and ix_to_word: dictionaries for the word-to-index and index-to-word mappings
- input_data and target_data: the dataset converted into input-output pairs
- embedding_dim and hidden_dim: the dimensions of the embeddings and of the RNN hidden state
You can inspect the data variable in the console to see the example sentences.
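The preloaded objects are not shown in the exercise. As a minimal sketch of how they might be built, assuming `data` is a list of whitespace-tokenized sentences and that each input is a sentence minus its last word with that last word as the target (the pairing scheme and the toy sentences are both assumptions, not the course's actual data):

```python
# Hypothetical toy sentences; the exercise's real `data` may differ.
data = [
    "the cat sat on the mat",
    "the dog sat on the rug",
    "a cat chased the dog",
]

# Vocabulary: unique words across all sentences, sorted for a stable index order
vocab = sorted({word for sentence in data for word in sentence.split()})
vocab_size = len(vocab)

# Word <-> index lookups
word_to_ix = {word: i for i, word in enumerate(vocab)}
ix_to_word = {i: word for word, i in word_to_ix.items()}

# Assumed pairing: input = all words except the last, target = the last word
input_data = [[word_to_ix[w] for w in s.split()[:-1]] for s in data]
target_data = [word_to_ix[s.split()[-1]] for s in data]

print(vocab_size, target_data)
```

Sorting the vocabulary is a choice made here so that indices are reproducible between runs; iterating over a raw Python set does not guarantee that.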
This exercise is part of the course Deep Learning for Text with PyTorch.
Exercise instructions
- Create an embedding layer for the vocabulary with the given embedding_dim.
- Apply a linear transformation to the RNN sequence output to get the attention scores.
- Get the attention weights from the scores by applying a softmax over the sequence dimension.
- Compute the context vector as the sum of the RNN outputs weighted by the attention weights.
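The attention steps (scores, weights, context vector) can be traced on a dummy tensor to see how the shapes change. This sketch assumes the batch_first=True layout used by the model and substitutes made-up sizes and random values for a real RNN output. It also computes the same context vector with a batched matrix multiply (torch.bmm), shown only as an equivalent alternative to the broadcast-and-sum form, not as what the exercise requires.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes for illustration only
batch_size, seq_len, hidden_dim = 2, 5, 8

# Stand-in for the RNN sequence output `out`: (batch, seq_len, hidden_dim)
out = torch.randn(batch_size, seq_len, hidden_dim)

# Scores: one scalar per time step, (batch, seq_len, 1) -> squeeze -> (batch, seq_len)
attention = nn.Linear(hidden_dim, 1)
scores = attention(out).squeeze(2)

# Weights: softmax over the sequence dimension, so each row sums to 1
attn_weights = F.softmax(scores, dim=1)

# Context: unsqueeze to (batch, seq_len, 1), broadcast over hidden_dim, sum out seq_len
context = torch.sum(attn_weights.unsqueeze(2) * out, dim=1)

# Equivalent: (batch, 1, seq_len) @ (batch, seq_len, hidden_dim) -> (batch, 1, hidden_dim)
context_bmm = torch.bmm(attn_weights.unsqueeze(1), out).squeeze(1)

print(scores.shape, attn_weights.shape, context.shape)
# torch.Size([2, 5]) torch.Size([2, 5]) torch.Size([2, 8])
```

The unsqueeze(2) is what makes the weighted sum work: it turns the (batch, seq_len) weights into (batch, seq_len, 1) so each weight scales every hidden unit at its time step.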
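import torch
import torch.nn as nn

# vocab_size, embedding_dim and hidden_dim are the preloaded globals described above
class RNNWithAttentionModel(nn.Module):
    def __init__(self):
        super(RNNWithAttentionModel, self).__init__()
        # Create an embedding layer for the vocabulary
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        # Apply a linear transformation to get the attention scores (one per time step)
        self.attention = nn.Linear(hidden_dim, 1)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        x = self.embeddings(x)   # (batch, seq_len, embedding_dim)
        out, _ = self.rnn(x)     # (batch, seq_len, hidden_dim)
        # Get the attention weights: softmax of the scores over the sequence dimension
        attn_weights = torch.nn.functional.softmax(self.attention(out).squeeze(2), dim=1)
        # Compute the context vector as the attention-weighted sum of the RNN outputs
        context = torch.sum(attn_weights.unsqueeze(2) * out, dim=1)
        out = self.fc(context)   # (batch, vocab_size) logits for the next word
        return out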
attention_model = RNNWithAttentionModel()
optimizer = torch.optim.Adam(attention_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
print("Model Instantiated")
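To try the model outside the exercise environment, here is a self-contained sketch of a short training run. It differs from the exercise code in one design choice: the dimensions are passed to the constructor instead of being read from preloaded globals. The toy sentences, the sizes (embedding_dim=16, hidden_dim=32), the 300 training steps and the last-word-as-target pairing are all hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class RNNWithAttentionModel(nn.Module):
    """Same architecture as the exercise, with dimensions as constructor arguments."""
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.attention = nn.Linear(hidden_dim, 1)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        out, _ = self.rnn(self.embeddings(x))
        attn_weights = F.softmax(self.attention(out).squeeze(2), dim=1)
        context = torch.sum(attn_weights.unsqueeze(2) * out, dim=1)
        return self.fc(context)

# Hypothetical toy data: input = sentence minus its last word, target = last word
sentences = ["the cat sat on the mat", "the dog sat on the rug", "a cat chased the dog"]
vocab = sorted({w for s in sentences for w in s.split()})
word_to_ix = {w: i for i, w in enumerate(vocab)}
ix_to_word = {i: w for w, i in word_to_ix.items()}
inputs = [torch.tensor([word_to_ix[w] for w in s.split()[:-1]]) for s in sentences]
targets = torch.tensor([word_to_ix[s.split()[-1]] for s in sentences])

# Pad the variable-length inputs into one (batch, max_len) tensor
padded = nn.utils.rnn.pad_sequence(inputs, batch_first=True)

model = RNNWithAttentionModel(len(vocab), embedding_dim=16, hidden_dim=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

losses = []
for step in range(300):
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(padded), targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

model.eval()
with torch.no_grad():
    logits = model(padded)
predictions = [ix_to_word[i.item()] for i in logits.argmax(dim=1)]
print(f"loss {losses[0]:.3f} -> {losses[-1]:.3f}", predictions)
```

One simplification to be aware of: pad_sequence pads with index 0, which in this toy vocabulary is also a real word, and the softmax does not mask padded positions. A fuller version would reserve a dedicated padding index (for example through the padding_idx argument of nn.Embedding) and mask those scores before the softmax.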