
Train a CNN model for text

Well done defining the TextClassificationCNN class. PyBooks now needs to train the model so that it can accurately classify the sentiment of book reviews.

The following packages have been imported for you: torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim.

An instance of TextClassificationCNN, created with the arguments vocab_size and embed_dim, has also been loaded and saved as model.
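The training loop in this exercise relies on a data list and a word_to_ix dictionary that the exercise environment provides but does not show. As a rough sketch, assuming data is a list of (tokenized sentence, label) pairs, a vocabulary and the matching vocab_size might be built as follows. The toy reviews, the "<unk>" entry at index 0, and embed_dim = 10 are all hypothetical.

```python
import torch

# Hypothetical toy dataset: (tokenized sentence, label) pairs, 1 = positive, 0 = negative
data = [
    ("i love this book".split(), 1),
    ("the plot was boring".split(), 0),
]

# Assumption: reserve index 0 for unknown words, so real words start at 1
word_to_ix = {"<unk>": 0}
for sentence, _ in data:
    for word in sentence:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)

vocab_size = len(word_to_ix)  # number of rows in the embedding table
embed_dim = 10                # hypothetical embedding width

# Encode a new sentence the same way the training loop does, then add a batch dimension
new_sentence = "i love this plot twist".split()
encoded = torch.LongTensor([word_to_ix.get(w, 0) for w in new_sentence]).unsqueeze(0)
print(vocab_size)  # 9
print(encoded)     # tensor([[1, 2, 3, 6, 0]])
```

Reserving index 0 for unknown words means word_to_ix.get(w, 0) sends unseen words such as "twist" to a dedicated slot instead of silently reusing a real word's embedding. The exercise does not state whether its own word_to_ix does this.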

This exercise is part of the course Deep Learning for Text with PyTorch.

Exercise instructions

  • Define a loss function used for binary classification and save it as criterion.
  • Zero the gradients at the start of the training loop.
  • Update the parameters at the end of the loop.
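The three steps correspond to standard PyTorch calls. The small sketch below uses hand-picked logits in place of the real model (an assumption for illustration), and it assumes the model outputs two logits per review, which the torch.LongTensor([int(label)]) class-index labels in the loop suggest. It shows that nn.CrossEntropyLoss on two logits is a valid binary-classification loss, and why gradients have to be cleared on every iteration.

```python
import torch
import torch.nn as nn

# Two raw logits for one review (negative, positive), as a 2-output model would produce
logits = torch.tensor([[0.3, 1.2]], requires_grad=True)
label = torch.LongTensor([1])  # class index, same dtype the exercise uses

# CrossEntropyLoss applies log-softmax internally, so the model should output raw logits
ce = nn.CrossEntropyLoss()(logits, label)

# Equivalent single-logit formulation: BCEWithLogitsLoss on the logit difference
bce = nn.BCEWithLogitsLoss()(logits[:, 1] - logits[:, 0], label.float())
print(torch.allclose(ce, bce))  # True

# Gradients accumulate across backward() calls until they are cleared
ce.backward(retain_graph=True)
first = logits.grad.clone()
ce.backward()
print(torch.allclose(logits.grad, 2 * first))  # True

# Clearing: model.zero_grad() and optimizer.zero_grad() set .grad to None by default since PyTorch 2.0
logits.grad = None
```

Without the zero-gradient step, each optimizer update would use the sum of all gradients computed so far rather than the gradient of the current example.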

Hands-on interactive exercise

Try this exercise by completing the sample code below.

# Define the loss function
criterion = nn.____()
optimizer = optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10):
    for sentence, label in data:     
        # Clear the gradients
        model.____()
        sentence = torch.LongTensor([word_to_ix.get(w, 0) for w in sentence]).unsqueeze(0) 
        label = torch.LongTensor([int(label)])
        outputs = model(sentence)
        loss = criterion(outputs, label)
        loss.backward()
        # Update the parameters
        ____.____()
print('Training complete!')
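For reference, here is one possible completed version of the loop that runs on its own. The TextClassificationCNN architecture below (an embedding, one Conv1d layer, mean pooling over positions, and a two-output linear layer) is an assumption standing in for the class from the previous exercise, and the four reviews are made-up toy data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

# Assumed architecture; the real TextClassificationCNN comes from an earlier exercise
class TextClassificationCNN(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(embed_dim, 2)  # two logits: negative, positive

    def forward(self, text):
        # (batch, seq_len) -> (batch, embed_dim, seq_len), the layout Conv1d expects
        embedded = self.embedding(text).permute(0, 2, 1)
        conved = F.relu(self.conv(embedded))
        pooled = conved.mean(dim=2)  # average over sequence positions
        return self.fc(pooled)

# Hypothetical toy reviews (1 = positive, 0 = negative)
data = [
    ("i love this book".split(), 1),
    ("the plot was boring".split(), 0),
    ("a wonderful and moving story".split(), 1),
    ("i hated the ending".split(), 0),
]
words = sorted({w for sentence, _ in data for w in sentence})
word_to_ix = {w: i for i, w in enumerate(words)}
model = TextClassificationCNN(vocab_size=len(word_to_ix), embed_dim=10)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10):
    for sentence, label in data:
        model.zero_grad()  # clear gradients left over from the previous step
        sentence = torch.LongTensor([word_to_ix.get(w, 0) for w in sentence]).unsqueeze(0)
        label = torch.LongTensor([int(label)])
        outputs = model(sentence)  # shape (1, 2): raw logits
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()  # apply the SGD update using the fresh gradients
print('Training complete!')
```

Here the blanks are filled with nn.CrossEntropyLoss, model.zero_grad and optimizer.step. Calling optimizer.zero_grad() instead of model.zero_grad() would behave the same in this loop, because the optimizer was given every parameter of the model.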