CommencerCommencer gratuitement

Compromis précision-rappel

Dans les tâches de classification, on rencontre souvent le terme compromis précision-rappel. D’où vient-il ?

En général, on choisit la classe ayant la probabilité la plus élevée pour affecter un document. Mais que faire si la probabilité maximale vaut 0.1 ? Devriez-vous considérer que ce document appartient à cette classe avec seulement 10 % de probabilité ?

La réponse dépend du problème. Il est possible de fixer un seuil minimal pour accepter la classification, et lorsque vous modifiez ce seuil, les valeurs de précision et de rappel évoluent en sens opposé.

Les variables y_true et le modèle model sont chargés. De plus, si la probabilité est inférieure au seuil, le document sera affecté à DEFAULT_CLASS (choisie comme la classe 2).

Cet exercice fait partie du cours

Réseaux de neurones récurrents (RNN) pour la modélisation du langage avec Keras

Afficher le cours

Instructions

  • Utilisez la méthode .predict() pour obtenir les probabilités de chaque classe et stockez-les dans la variable pred_probabilities.
  • Acceptez la probabilité maximale uniquement si elle est supérieure ou égale à 0.5 et stockez les résultats dans la variable y_pred_50.
  • Utilisez les fonctions np.argmax() et np.max() pour faire la même chose avec un seuil égal à 0.8.
  • Affichez la variable trade_off avec toutes les métriques.

Exercice interactif pratique

Essayez cet exercice en complétant cet exemple de code.

# Get probabilities for each class
pred_probabilities = model.____(X_test)

# Thresholds at 0.5 and 0.8
y_pred_50 = [np.argmax(x) if np.max(x) >= ____ else DEFAULT_CLASS for x in pred_probabilities]
y_pred_80 = [np.____(x) if np.____(x) >= 0.8 else DEFAULT_CLASS for x in pred_probabilities]

trade_off = pd.DataFrame({
    'Precision_50': precision_score(y_true, y_pred_50, average=None), 
    'Precision_80': precision_score(y_true, y_pred_80, average=None), 
    'Recall_50': recall_score(y_true, y_pred_50, average=None), 
    'Recall_80': recall_score(y_true, y_pred_80, average=None)}, 
  index=['Class 1', 'Class 2', 'Class 3'])

____
Modifier et exécuter le code