Transformer-Modell trainieren und testen
Nachdem das TransformerEncoder-Modell steht, ist der nächste Schritt bei PyBooks, das Modell mit Beispielrezensionen zu trainieren und seine Leistung zu bewerten. Das Training auf diesen Beispielrezensionen hilft PyBooks, Stimmungstrends im großen Bestand zu verstehen. Mit einem gut funktionierenden Modell kann PyBooks die Sentiment-Analyse automatisieren und Leserinnen und Lesern dadurch treffende Empfehlungen und Feedback liefern.
Die folgenden Pakete wurden für dich importiert: torch, nn, optim.
Die model-Instanz der Klasse TransformerEncoder, token_embeddings sowie train_sentences, train_labels, test_sentences, test_labels sind für dich vorkonfiguriert.
Diese Übung ist Teil des Kurses
Deep Learning für Text mit PyTorch
Anleitung zur Übung
- Teile in der Trainingsschleife die Sätze in Tokens auf und stapel die Embeddings.
- Setze die Gradienten auf null und führe einen Backward-Pass aus.
- Deaktiviere in der Funktion
predictdie Gradientenberechnungen und ermittle dann die Sentiment-Vorhersage.
Interaktive Übung
Vervollständige den Beispielcode, um diese Übung erfolgreich abzuschließen.
for epoch in range(5):
for sentence, label in zip(train_sentences, train_labels):
# Split the sentences into tokens and stack the embeddings
tokens = ____
data = torch.____([token_embeddings[token] for token in ____], dim=1)
output = model(data)
loss = criterion(output, torch.tensor([label]))
# Zero the gradients and perform a backward pass
optimizer.____()
loss.____()
optimizer.step()
print(f"Epoch {epoch}, Loss: {loss.item()}")
def predict(sentence):
model.eval()
# Deactivate the gradient computations and get the sentiment prediction.
with torch.____():
tokens = sentence.split()
data = torch.stack([token_embeddings.get(token, torch.rand((1, 512))) for token in tokens], dim=1)
output = model(data)
predicted = torch.____(output, dim=1)
return "Positive" if predicted.item() == 1 else "Negative"
sample_sentence = "This product can be better"
print(f"'{sample_sentence}' is {predict(sample_sentence)}")