IniziaInizia gratis

Valutare i modelli di classificazione RNN

Il team di PyBooks ora vuole che tu valuti il modello RNN che hai creato ed eseguito usando il dataset Newsgroup. Ricorda: l'obiettivo era classificare gli articoli in una delle tre categorie:

rec.autos, sci.med e comp.graphics.

Il modello è stato addestrato e hai stampato epoca e loss per ciascun modello.

Usa torchmetrics per valutare varie metriche per il tuo modello. Sono già state caricate: Accuracy, Precision, Recall, F1Score.

Un'istanza di rnn_model addestrata nell'esercizio precedente è già precaricata per te.

Questo esercizio fa parte del corso

Deep Learning per il testo con PyTorch

Visualizza il corso

Istruzioni dell'esercizio

  • Crea un'istanza di ciascuna metrica per la classificazione multiclasse con num_classes uguale al numero di categorie.
  • Genera le predizioni per rnn_model usando i dati di test X_test_seq.
  • Calcola le metriche usando le classi previste e le etichette vere.

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

# Create an instance of the metrics
accuracy = Accuracy(task="multiclass", ____)
precision = Precision(____, num_classes=____)
recall = Recall(task=____, num_classes=____)
f1 = F1Score(____, ____)

# Generate the predictions
outputs = ____(X_test_seq)
_, predicted = ____.____(outputs, 1)

# Calculate the metrics
accuracy_score = accuracy(____, y_test_seq)
precision_score = precision(____, y_test_seq)
recall_score = recall(____, y_test_seq)
f1_score = f1(____, y_test_seq)
print("RNN Model - Accuracy: {}, Precision: {}, Recall: {}, F1 Score: {}".format(accuracy_score, precision_score, recall_score, f1_score))
Modifica ed esegui il codice