Aan de slagGa gratis aan de slag

Trade-off tussen precisie en recall

Bij classificatietaken kom je vaak de term Precision-Recall trade-off tegen. Waar komt die vandaan?

Meestal kies je de klasse met de hoogste waarschijnlijkheid om het document aan toe te wijzen. Maar wat als de maximale waarschijnlijkheid 0.1 is? Moet je dan vinden dat dit document met slechts 10% kans tot deze klasse behoort?

Het antwoord hangt af van het probleem. Je kunt een minimale drempel instellen om de classificatie te accepteren, en als je die drempel aanpast, bewegen de waarden voor precision en recall in tegengestelde richting.

De variabelen y_true en het model model zijn al geladen. Als de waarschijnlijkheid lager is dan de drempel, wordt het document toegewezen aan DEFAULT_CLASS (gekozen als klasse 2).

Deze oefening maakt deel uit van de cursus

Recurrent Neural Networks (RNN's) voor taalmodellen met Keras

Cursus bekijken

Oefeninstructies

  • Gebruik de .predict()-methode om de waarschijnlijkheden voor elke klasse op te halen en sla ze op in de variabele pred_probabilities.
  • Accepteer de maximale waarschijnlijkheid alleen als die groter dan of gelijk aan 0.5 is en sla de resultaten op in de variabele y_pred_50.
  • Gebruik de functies np.argmax() en np.max() om hetzelfde te doen voor een drempel gelijk aan 0.8.
  • Print de variabele trade_off met alle statistieken.

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

# Get probabilities for each class
pred_probabilities = model.____(X_test)

# Thresholds at 0.5 and 0.8
y_pred_50 = [np.argmax(x) if np.max(x) >= ____ else DEFAULT_CLASS for x in pred_probabilities]
y_pred_80 = [np.____(x) if np.____(x) >= 0.8 else DEFAULT_CLASS for x in pred_probabilities]

trade_off = pd.DataFrame({
    'Precision_50': precision_score(y_true, y_pred_50, average=None), 
    'Precision_80': precision_score(y_true, y_pred_80, average=None), 
    'Recall_50': recall_score(y_true, y_pred_50, average=None), 
    'Recall_80': recall_score(y_true, y_pred_80, average=None)}, 
  index=['Class 1', 'Class 2', 'Class 3'])

____
Code bewerken en uitvoeren