IniziaInizia gratis

Creare un modello GRU per il testo

A PyBooks, il team è rimasto colpito dalle prestazioni dei due modelli che hai addestrato in precedenza. Tuttavia, nella loro ricerca dell'eccellenza, vogliono assicurarsi di selezionare il modello assolutamente migliore per il compito da svolgere. Per questo ti hanno chiesto di ampliare ulteriormente il progetto sperimentando le capacità dei modelli GRU, noti per efficienza ed efficacia nei compiti di classificazione del testo. Il tuo nuovo incarico è applicare il modello GRU per classificare gli articoli del dataset Newsgroup nelle seguenti categorie:

rec.autos, sci.med e comp.graphics.

I seguenti pacchetti sono stati già caricati per te: torch, nn, optim.

Questo esercizio fa parte del corso

Deep Learning per il testo con PyTorch

Visualizza il corso

Istruzioni dell'esercizio

  • Completa la classe GRU con i parametri richiesti.
  • Inizializza il modello con gli stessi parametri.
  • Addestra il modello: passa i parametri alla funzione di loss (criterion) ed esegui la backpropagation della loss.

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

# Complete the GRU model
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = ____
        self.fc = nn.Linear(hidden_size, num_classes)       
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size) 
        out, _ = self.gru(x, h0)
        out = out[:, -1, :] 
        out = self.fc(out)
        return out

# Initialize the model
gru_model = ____
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(gru_model.parameters(), lr=0.01)

# Train the model and backpropagate the loss after initialization
for epoch in range(15): 
    optimizer.zero_grad()
    outputs = ____
    loss = criterion(____, y_train_seq)
    ____
    optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')
Modifica ed esegui il codice