Évaluer l'arbre décisionnel
Vous pouvez évaluer la qualité de votre modèle en évaluant ses performances sur les données de test. Étant donné que le modèle n'a pas été formé à partir de ces données, il s'agit d'une évaluation objective du modèle.
Une matrice de confusion fournit une ventilation utile des prédictions par rapport aux valeurs connues. Il comporte quatre cellules qui représentent les nombres suivants :
- Vrais négatifs (TN) — le modèle prédit un résultat négatif et le résultat connu est négatif
- Vrais positifs (TP) — le modèle prédit un résultat positif et le résultat connu est positif
- Faux négatifs (FN) — le modèle prédit un résultat négatif, mais le résultat connu est positif.
- Faux positifs (FP) — le modèle prédit un résultat positif, mais le résultat connu est négatif.
Ces comptes (TN
, TP
, FN
et FP
) doivent correspondre au nombre d'enregistrements dans les données de test, qui ne sont qu'un sous-ensemble des données de vols. Vous pouvez comparer ce nombre au nombre d'enregistrements dans les données de test, qui est flights_test.count()
.
Remarque : Ces prévisions sont établies à partir des données de test, les chiffres sont donc inférieurs à ceux qui auraient été obtenus à partir des données d'entraînement.
Cet exercice fait partie du cours
Apprentissage automatique avec PySpark
Instructions
- Créez une matrice de confusion en comptant les combinaisons de «
label
» et «prediction
». Affichez le résultat. - Veuillez compter le nombre de résultats négatifs vrais, de résultats positifs vrais, de résultats négatifs faux et de résultats positifs faux.
- Veuillez calculer la précision.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Create a confusion matrix
prediction.groupBy(____, 'prediction').____().____()
# Calculate the elements of the confusion matrix
TN = prediction.filter('prediction = 0 AND label = prediction').count()
TP = prediction.____('____ AND ____').____()
FN = prediction.____('____ AND ____').____()
FP = prediction.____('____ AND ____').____()
# Accuracy measures the proportion of correct predictions
accuracy = ____
print(accuracy)