Validation croisée du pipeline du modèle de durée de vol
Le modèle validé par croisé que vous venez de construire était simple, utilisant uniquement l'km
pour prédire l'duration
.
Un autre facteur important permettant de prédire la durée d'un vol est l'aéroport de départ. Les vols décollent généralement plus tard depuis les aéroports très fréquentés. Voyons si l'ajout de ce prédicteur améliore le modèle.
Dans cet exercice, vous allez ajouter le champ « org
» au modèle. Cependant, étant donné que org
est catégorique, il reste encore du travail à accomplir avant de pouvoir l'inclure : il doit d'abord être transformé en index, puis encodé en one-hot avant d'être assemblé avec km
et utilisé pour construire le modèle de régression. Nous allons regrouper ces opérations dans un pipeline.
Les objets suivants ont déjà été créés :
params
— une grille de paramètres videevaluator
— un évaluateur de régressionregression
— un objet d'LinearRegression
aveclabelCol='duration'
.
Les classes StringIndexer
, OneHotEncoder
, VectorAssembler
et CrossValidator
ont déjà été importées.
Cet exercice fait partie du cours
Apprentissage automatique avec PySpark
Instructions
- Créez un indexeur de chaîne. Veuillez spécifier les champs d'entrée et de sortie comme suit :
org
etorg_idx
. - Créez un encodeur one-hot. Nommez le champ de sortie «
org_dummy
». - Veuillez regrouper les champs «
km
» et «org_dummy
» en un seul champ appelé «features
». - Veuillez créer un pipeline à l'aide des opérations suivantes : indexeur de chaîne, encodeur one-hot, assembleur et régression linéaire. Veuillez utiliser ceci pour créer un validateur croisé.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Create an indexer for the org field
indexer = ____(____, ____)
# Create an one-hot encoder for the indexed org field
onehot = ____(____, ____)
# Assemble the km and one-hot encoded fields
assembler = ____(____, ____)
# Create a pipeline and cross-validator.
pipeline = ____(stages=[____, ____, ____, ____])
cv = ____(estimator=____,
estimatorParamMaps=____,
evaluator=____)