Train a CNN model for text
Well done defining the TextClassificationCNN class. PyBooks now needs to train the model to optimize it for accurate sentiment analysis of book reviews.
The following packages have been imported for you: torch, torch.nn as nn, torch.nn.functional as F, and torch.optim as optim.
An instance of TextClassificationCNN() with arguments vocab_size and embed_dim has also been loaded and saved as model.
This exercise is part of the course Deep Learning for Text with PyTorch
Exercise instructions
- Define a loss function used for binary classification and save as criterion.
- Zero the gradients at the start of the training loop.
- Update the parameters at the end of the loop.
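For the first instruction, one loss that fits the tensors in the sample code is nn.CrossEntropyLoss, because the label is built as an integer class index (torch.LongTensor). The sketch below assumes the model emits one raw logit per class (shape [1, 2]), which this exercise does not show; the logit values are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumption: the model returns raw logits of shape [batch, 2]
# (negative, positive). These values are made up for illustration.
logits = torch.tensor([[0.2, 1.5]])
label = torch.LongTensor([1])  # integer class index, as in the sample code

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, label)

# CrossEntropyLoss applies log-softmax internally, so the model
# should not apply softmax itself before the loss.
manual = -F.log_softmax(logits, dim=1)[0, label.item()]
print(loss.item(), manual.item())
```

If the model instead produced a single logit per review, nn.BCEWithLogitsLoss with float labels would be the usual alternative.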
Hands-on interactive exercise
Try this exercise by completing the sample code.
# Define the loss function
criterion = nn.____()
optimizer = optim.SGD(model.parameters(), lr=0.1)
for epoch in range(10):
    for sentence, label in data:
        # Clear the gradients
        model.____()
        sentence = torch.LongTensor([word_to_ix.get(w, 0) for w in sentence]).unsqueeze(0)
        label = torch.LongTensor([int(label)])
        outputs = model(sentence)
        loss = criterion(outputs, label)
        loss.backward()
        # Update the parameters
        ____.____()
print('Training complete!')
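One way the blanks might be filled in is sketched below as a self-contained script. The TextClassificationCNN class here is a hypothetical stand-in (the course's real class comes from an earlier exercise and may differ), and data and word_to_ix are made-up toy values, since the exercise preloads its own. The choice of nn.CrossEntropyLoss assumes the model outputs two logits per review.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-in for the course's TextClassificationCNN;
# the real class is defined in an earlier exercise and may differ.
class TextClassificationCNN(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=1)
        self.fc = nn.Linear(embed_dim, 2)  # two classes: negative / positive

    def forward(self, text):
        embedded = self.embedding(text).permute(0, 2, 1)  # [batch, embed_dim, seq_len]
        conved = F.relu(self.conv(embedded))
        pooled = conved.mean(dim=2)  # average over the sequence
        return self.fc(pooled)  # raw logits, shape [batch, 2]

# Made-up toy data: (tokenized review, label) pairs
data = [
    ("i loved this book".split(), 1),
    ("i hated this book".split(), 0),
]
vocab = sorted({w for sentence, _ in data for w in sentence})
word_to_ix = {w: i + 1 for i, w in enumerate(vocab)}  # 0 is reserved for unknown words

model = TextClassificationCNN(vocab_size=len(word_to_ix) + 1, embed_dim=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

epoch_losses = []
for epoch in range(10):
    total = 0.0
    for sentence, label in data:
        model.zero_grad()  # clear gradients left from the previous step
        sentence = torch.LongTensor([word_to_ix.get(w, 0) for w in sentence]).unsqueeze(0)
        label = torch.LongTensor([int(label)])
        outputs = model(sentence)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()  # update the parameters
        total += loss.item()
    epoch_losses.append(total)

print('Training complete!')
```

Calling optimizer.zero_grad() would work equally well here, since the optimizer holds all of the model's parameters. Gradients accumulate across backward() calls in PyTorch, which is why they are cleared at the start of each iteration.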