Aan de slagGa gratis aan de slag

RNN-classificatiemodellen evalueren

Het team van PyBooks wil nu dat je het RNN-model beoordeelt dat je hebt gebouwd en uitgevoerd met de Newsgroup-gegevensset. Weet je nog: het doel was om de artikelen in een van drie categorieën in te delen:

rec.autos, sci.med en comp.graphics.

Het model is getraind en je hebt voor elk epoch de loss geprint.

Gebruik torchmetrics om verschillende metrics voor je model te evalueren. Het volgende is voor je geladen: Accuracy, Precision, Recall, F1Score.

Een instantie van rnn_model die in de vorige oefening is getraind, is ook voor je vooringeladen.

Deze oefening maakt deel uit van de cursus

Deep Learning voor tekst met PyTorch

Cursus bekijken

Oefeninstructies

  • Maak een instantie van elke metric voor multi-class-classificatie met num_classes gelijk aan het aantal categorieën.
  • Genereer de voorspellingen van het rnn_model met de testgegevens X_test_seq.
  • Bereken de metrics met de voorspelde klassen en de echte labels.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

# Create an instance of the metrics
accuracy = Accuracy(task="multiclass", ____)
precision = Precision(____, num_classes=____)
recall = Recall(task=____, num_classes=____)
f1 = F1Score(____, ____)

# Generate the predictions
outputs = ____(X_test_seq)
_, predicted = ____.____(outputs, 1)

# Calculate the metrics
accuracy_score = accuracy(____, y_test_seq)
precision_score = precision(____, y_test_seq)
recall_score = recall(____, y_test_seq)
f1_score = f1(____, y_test_seq)
print("RNN Model - Accuracy: {}, Precision: {}, Recall: {}, F1 Score: {}".format(accuracy_score, precision_score, recall_score, f1_score))
Code bewerken en uitvoeren