Entraîner et tester le modèle Transformer
Avec le modèle TransformerEncoder en place, l’étape suivante chez PyBooks consiste à entraîner le modèle sur des avis d’exemple et à évaluer ses performances. L’entraînement sur ces avis aidera PyBooks à comprendre les tendances de sentiment dans leur vaste référentiel. En obtenant un modèle performant, PyBooks pourra automatiser l’analyse de sentiment, afin que les lecteurs reçoivent des recommandations et des retours pertinents.
Les packages suivants ont été importés pour vous : torch, nn, optim.
L’instance model de la classe TransformerEncoder, token_embeddings, ainsi que train_sentences, train_labels, test_sentences, test_labels sont préchargés pour vous.
Cet exercice fait partie du cours
Deep Learning pour le texte avec PyTorch
Instructions
- Dans la boucle d’entraînement, découpez les phrases en jetons puis empilez les embeddings.
- Remettez les gradients à zéro et effectuez une rétropropagation.
- Dans la fonction
predict, désactivez les calculs de gradient puis récupérez la prédiction de sentiment.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
for epoch in range(5):
for sentence, label in zip(train_sentences, train_labels):
# Split the sentences into tokens and stack the embeddings
tokens = ____
data = torch.____([token_embeddings[token] for token in ____], dim=1)
output = model(data)
loss = criterion(output, torch.tensor([label]))
# Zero the gradients and perform a backward pass
optimizer.____()
loss.____()
optimizer.step()
print(f"Epoch {epoch}, Loss: {loss.item()}")
def predict(sentence):
model.eval()
# Deactivate the gradient computations and get the sentiment prediction.
with torch.____():
tokens = sentence.split()
data = torch.stack([token_embeddings.get(token, torch.rand((1, 512))) for token in tokens], dim=1)
output = model(data)
predicted = torch.____(output, dim=1)
return "Positive" if predicted.item() == 1 else "Negative"
sample_sentence = "This product can be better"
print(f"'{sample_sentence}' is {predict(sample_sentence)}")