Ein CNN-Modell für Text trainieren
Gut gemacht beim Definieren der Klasse TextClassificationCNN. PyBooks muss das Modell jetzt trainieren, um es für eine genaue Sentiment-Analyse von Buchrezensionen zu optimieren.
Die folgenden Pakete wurden für dich importiert:
torch, torch.nn als nn, torch.nn.functional als F, torch.optim als optim.
Eine Instanz von TextClassificationCNN() mit den Argumenten vocab_size und embed_dim wurde ebenfalls geladen und als model gespeichert.
Diese Übung ist Teil des Kurses
Deep Learning für Text mit PyTorch
Anleitung zur Übung
- Definiere eine Loss-Funktion für binäre Klassifikation und speichere sie als
criterion. - Setze die Gradienten zu Beginn der Trainingsschleife auf null.
- Aktualisiere die Parameter am Ende der Schleife.
Interaktive Übung
Vervollständige den Beispielcode, um diese Übung erfolgreich abzuschließen.
# Define the loss function
criterion = nn.____()
optimizer = optim.SGD(model.parameters(), lr=0.1)
for epoch in range(10):
for sentence, label in data:
# Clear the gradients
model.____()
sentence = torch.LongTensor([word_to_ix.get(w, 0) for w in sentence]).unsqueeze(0)
label = torch.LongTensor([int(label)])
outputs = model(sentence)
loss = criterion(outputs, label)
loss.backward()
# Update the parameters
____.____()
print('Training complete!')