Aan de slagGa gratis aan de slag

Tekstgeneratie met RNN - Trainen en genereren

Het team van PyBooks wil nu dat je het RNN-model traint en test. Dit model is ontworpen om, op basis van de gegeven invoer, het volgende teken in de reeks te voorspellen voor het automatisch aanvullen van boektitels. Dit project helpt het team om modellen voor tekstaanvulling verder te ontwikkelen.

De model-instance voor de klasse RNNmodel is alvast voor je geladen. De variabele data is voorbewerkt en gecodeerd als een sequentie.

De variabelen inputs en targets zijn alvast voor je geladen.

Deze oefening maakt deel uit van de cursus

Deep Learning voor tekst met PyTorch

Cursus bekijken

Oefeninstructies

  • Initialiseer de verliesfunctie die gebruikt wordt om de fout van ons model te berekenen.
  • Initialiseer de optimizer uit de optimalisatiemodule van PyTorch.
  • Voer het trainingstraject uit door het model in de train-modus te zetten en de gradients te resetten voordat je een optimalisatiestap uitvoert.
  • Zet na het trainen het model in evaluatiemodus om het te testen op een voorbeeldding.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

# Instantiate the loss function
criterion = nn.____()
# Instantiate the optimizer
optimizer = torch.optim.____(model.parameters(), lr=0.01)

# Train the model
for epoch in range(100):
    model.____()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    optimizer.____()
    loss.backward()
    optimizer.step()
    if (epoch+1) % 10 == 0:
        print(f'Epoch {epoch+1}/100, Loss: {loss.item()}')

# Test the model
model.____()
test_input = char_to_ix['r']
test_input = nn.functional.one_hot(torch.tensor(test_input).view(-1, 1), num_classes=len(chars)).float()
predicted_output = model(test_input)
predicted_char_ix = torch.argmax(predicted_output, 1).item()
print(f"Test Input: 'r', Predicted Output: '{ix_to_char[predicted_char_ix]}'")
Code bewerken en uitvoeren