Évaluer l'arbre de décision
Vous pouvez évaluer la qualité de votre modèle en évaluant ses performances sur les données de test. Le modèle n'ayant pas été entraîné sur ces données, il s'agit d'une évaluation objective du modèle.
Une matrice de confusion donne une ventilation utile des prédictions par rapport aux valeurs connues. Il comporte quatre cellules qui représentent les nombres de :
- Vrais négatifs (TN) - le modèle prédit un résultat négatif et le résultat connu est négatif.
- Vrais positifs (TP) - le modèle prédit un résultat positif et le résultat connu est positif.
- Faux négatifs (FN) - le modèle prédit un résultat négatif alors que le résultat connu est positif.
- Faux positifs (FP) - le modèle prédit un résultat positif alors que le résultat connu est négatif.
Ces chiffres (TN, TP, FN et FP) doivent correspondre au nombre d'enregistrements dans les données de test, qui ne sont qu'un sous-ensemble des données de vol. Vous pouvez comparer avec le nombre d'enregistrements dans les données des tests, qui est de flights_test.count().
Remarque : Ces prédictions sont effectuées sur les données de test, de sorte que les chiffres sont inférieurs à ce qu'ils auraient été pour des prédictions sur les données d'apprentissage.
Cet exercice fait partie du cours
Apprentissage automatique avec PySpark
Instructions
- Créez une matrice de confusion en comptant les combinaisons de
labeletprediction. Affichez le résultat. - Comptez le nombre de vrais négatifs, de vrais positifs, de faux négatifs et de faux positifs.
- Calculez la précision.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Create a confusion matrix
prediction.groupBy(____, 'prediction').____().____()
# Calculate the elements of the confusion matrix
TN = prediction.filter('prediction = 0 AND label = prediction').count()
TP = prediction.____('____ AND ____').____()
FN = prediction.____('____ AND ____').____()
FP = prediction.____('____ AND ____').____()
# Accuracy measures the proportion of correct predictions
accuracy = ____
print(accuracy)