Avaliação de modelos multiclasse
Vamos avaliar nosso classificador de nuvem com precisão e chamá-lo novamente para ver se ele consegue classificar bem os sete tipos de nuvem. Nessa tarefa de classificação multiclasse, é importante saber como você calcula a média das pontuações das classes. Lembre-se de que há quatro abordagens:
- Não calcular a média e analisar os resultados por classe;
- Fazer a micromédia, ignorando as classes e calculando as métricas globalmente;
- Fazer a macromédia, computando métricas por classe e calculando a média delas;
- Fazer a média ponderada, assim como a macro, mas com a média ponderada pelo tamanho da classe.
Tanto Precision
quanto Recall
já foram importados do torchmetrics
. É hora de ver se o nosso modelo está indo bem!
Este exercício faz parte do curso
Aprendizagem profunda intermediária com PyTorch
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
# Define metrics
metric_precision = Precision(task=____, num_classes=____, average=____)
metric_recall = ____
net.eval()
with torch.no_grad():
for images, labels in dataloader_test:
outputs = net(images)
_, preds = torch.max(outputs, 1)
metric_precision(preds, labels)
metric_recall(preds, labels)
precision = metric_precision.compute()
recall = metric_recall.compute()
print(f"Precision: {precision}")
print(f"Recall: {recall}")