IniziaInizia gratis

Calcolare l'accuracy con torchmetrics

Tenere traccia dell'accuracy durante l'addestramento aiuta a identificare l'epoca con le performance migliori.

In questo esercizio userai torchmetrics per calcolare l'accuracy su un insieme di dati di mascherine con tre classi. La funzione plot_errors metterà in evidenza i campioni classificati in modo errato, aiutandoti ad analizzare gli errori del modello.

Il pacchetto torchmetrics è già importato. Gli outputs del modello sono probabilità softmax e le labels sono vettori one-hot encoded.

Questo esercizio fa parte del corso

Introduzione al Deep Learning con PyTorch

Visualizza il corso

Esercizio pratico interattivo

Prova a risolvere questo esercizio completando il codice di esempio.

# Create accuracy metric
metric = torchmetrics.____(____, ____)
for features, labels in dataloader:
    outputs = model(features)
  
    # Calculate accuracy over the batch
    metric.____(____, ____)
Modifica ed esegui il codice