CommencerCommencer gratuitement

Créer un modèle GRU pour le texte

Chez PyBooks, l’équipe a été impressionnée par les performances des deux modèles que vous avez entraînés précédemment. Cependant, dans leur quête d’excellence, elle souhaite s’assurer de sélectionner le meilleur modèle possible pour la tâche. Elle vous a donc demandé d’étendre le projet en expérimentant les capacités des modèles GRU, réputés pour leur efficacité dans les tâches de classification de texte. Votre nouvelle mission consiste à appliquer un modèle GRU pour classer les articles du jeu de données Newsgroup dans les catégories suivantes :

rec.autos, sci.med et comp.graphics.

Les packages suivants ont été chargés pour vous : torch, nn, optim.

Cet exercice fait partie du cours

Deep Learning pour le texte avec PyTorch

Afficher le cours

Instructions

  • Complétez la classe GRU avec les paramètres requis.
  • Initialisez le modèle avec les mêmes paramètres.
  • Entraînez le modèle : passez les paramètres à la fonction de critère et rétropropagez la perte.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

# Complete the GRU model
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = ____
        self.fc = nn.Linear(hidden_size, num_classes)       
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size) 
        out, _ = self.gru(x, h0)
        out = out[:, -1, :] 
        out = self.fc(out)
        return out

# Initialize the model
gru_model = ____
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(gru_model.parameters(), lr=0.01)

# Train the model and backpropagate the loss after initialization
for epoch in range(15): 
    optimizer.zero_grad()
    outputs = ____
    loss = criterion(____, y_train_seq)
    ____
    optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')
Modifier et exécuter le code