Calcolare l'accuracy con torchmetrics
Tenere traccia dell'accuracy durante l'addestramento aiuta a identificare l'epoca con le performance migliori.
In questo esercizio userai torchmetrics per calcolare l'accuracy su un insieme di dati di mascherine con tre classi. La funzione plot_errors metterà in evidenza i campioni classificati in modo errato, aiutandoti ad analizzare gli errori del modello.
Il pacchetto torchmetrics è già importato. Gli outputs del modello sono probabilità softmax e le labels sono vettori one-hot encoded.
Questo esercizio fa parte del corso
Introduzione al Deep Learning con PyTorch
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
# Create accuracy metric
metric = torchmetrics.____(____, ____)
for features, labels in dataloader:
outputs = model(features)
# Calculate accuracy over the batch
metric.____(____, ____)