Een RNN-model met attention maken
Bij PyBooks verkent het team verschillende deep learning-architecturen. Na wat onderzoek besluit je een RNN met een attention-mechanisme te implementeren om het volgende woord in een zin te voorspellen. Je krijgt een gegevensset met zinnen en een daaruit opgebouwde woordenschat.
De volgende pakketten zijn alvast voor je geïmporteerd: torch, nn.
Het volgende is voor je vooraf geladen:
vocabenvocab_size: de woordenschat en de grootte ervanword_to_ixenix_to_word: woordenboek voor koppelingen van woord-naar-index en index-naar-woordinput_dataentarget_data: de gegevensset omgezet naar input-outputparenembedding_dimenhidden_dim: dimensies voor de embedding en de verborgen toestand van de RNN
Je kunt de variabele data in de console bekijken om de voorbeeldzinnen te zien.
Deze oefening maakt deel uit van de cursus
Deep Learning voor tekst met PyTorch
Oefeninstructies
- Maak een embedding-laag voor de woordenschat met de gegeven
embedding_dim. - Pas een lineaire transformatie toe op de RNN-sequentie-uitvoer om de attentiescores te krijgen.
- Haal de aandachtgewichten uit de score.
- Bereken de contextvector als de gewogen som van de RNN-uitgangen en de aandachtgewichten.
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
class RNNWithAttentionModel(nn.Module):
def __init__(self):
super(RNNWithAttentionModel, self).__init__()
# Create an embedding layer for the vocabulary
self.embeddings = nn.____(vocab_size, embedding_dim)
self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
# Apply a linear transformation to get the attention scores
self.attention = nn.____(____, 1)
self.fc = nn.____(hidden_dim, vocab_size)
def forward(self, x):
x = self.embeddings(x)
out, _ = self.rnn(x)
# Get the attention weights
attn_weights = torch.nn.functional.____(self.____(out).____(2), dim=1)
# Compute the context vector
context = torch.sum(____.____(2) * out, dim=1)
out = self.fc(context)
return out
attention_model = RNNWithAttentionModel()
optimizer = torch.optim.Adam(attention_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
print("Model Instantiated")