Valutare i modelli di classificazione RNN
Il team di PyBooks ora vuole che tu valuti il modello RNN che hai creato ed eseguito usando il dataset Newsgroup. Ricorda: l'obiettivo era classificare gli articoli in una delle tre categorie:
rec.autos, sci.med e comp.graphics.
Il modello è stato addestrato e hai stampato epoca e loss per ciascun modello.
Usa torchmetrics per valutare varie metriche per il tuo modello. Sono già state caricate: Accuracy, Precision, Recall, F1Score.
Un'istanza di rnn_model addestrata nell'esercizio precedente è già precaricata per te.
Questo esercizio fa parte del corso
Deep Learning per il testo con PyTorch
Istruzioni dell'esercizio
- Crea un'istanza di ciascuna metrica per la classificazione multiclasse con
num_classesuguale al numero di categorie. - Genera le predizioni per
rnn_modelusando i dati di testX_test_seq. - Calcola le metriche usando le classi previste e le etichette vere.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
# Create an instance of the metrics
accuracy = Accuracy(task="multiclass", ____)
precision = Precision(____, num_classes=____)
recall = Recall(task=____, num_classes=____)
f1 = F1Score(____, ____)
# Generate the predictions
outputs = ____(X_test_seq)
_, predicted = ____.____(outputs, 1)
# Calculate the metrics
accuracy_score = accuracy(____, y_test_seq)
precision_score = precision(____, y_test_seq)
recall_score = recall(____, y_test_seq)
f1_score = f1(____, y_test_seq)
print("RNN Model - Accuracy: {}, Precision: {}, Recall: {}, F1 Score: {}".format(accuracy_score, precision_score, recall_score, f1_score))