Aan de slagGa gratis aan de slag

Een GRU-model bouwen voor tekst

Bij PyBooks is het team onder de indruk van de prestaties van de twee modellen die je eerder trainde. In hun streven naar topkwaliteit willen ze echter zeker weten dat ze het allerbeste model voor deze taak kiezen. Daarom hebben ze je gevraagd het project uit te breiden door te experimenteren met de mogelijkheden van GRU-modellen, die bekendstaan om hun efficiëntie en effectiviteit bij tekstonthoudingstaken. Je nieuwe opdracht is om het GRU-model toe te passen om artikelen uit de Newsgroup-gegevensset te classificeren in de volgende categorieën:

rec.autos, sci.med en comp.graphics.

De volgende pakketten zijn alvast voor je geladen: torch, nn, optim.

Deze oefening maakt deel uit van de cursus

Deep Learning voor tekst met PyTorch

Cursus bekijken

Oefeninstructies

  • Vul de GRU-klasse aan met de vereiste parameters.
  • Initialiseer het model met dezelfde parameters.
  • Train het model: geef de parameters door aan de criteriumfunctie en voer backpropagation uit op het verlies.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

# Complete the GRU model
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = ____
        self.fc = nn.Linear(hidden_size, num_classes)       
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size) 
        out, _ = self.gru(x, h0)
        out = out[:, -1, :] 
        out = self.fc(out)
        return out

# Initialize the model
gru_model = ____
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(gru_model.parameters(), lr=0.01)

# Train the model and backpropagate the loss after initialization
for epoch in range(15): 
    optimizer.zero_grad()
    outputs = ____
    loss = criterion(____, y_train_seq)
    ____
    optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')
Code bewerken en uitvoeren