Visualiser des arbres XGBoost individuels
Maintenant que vous avez utilisé XGBoost pour entraîner et évaluer des modèles de régression et de classification, prenez le temps d’explorer vos modèles visuellement. Ici, vous allez visualiser des arbres individuels à partir du modèle entièrement boosté qu’XGBoost crée en utilisant l’ensemble complet de données immobilières.
XGBoost propose une fonction plot_tree() qui facilite ce type de visualisation. Une fois que vous avez entraîné un modèle avec l’API d’apprentissage XGBoost, vous pouvez le passer à la fonction plot_tree() avec le nombre d’arbres à tracer via l’argument num_trees.
Cet exercice fait partie du cours
Extreme Gradient Boosting avec XGBoost
Instructions
- Créez un dictionnaire de paramètres avec un
"objective"égal à"reg:squarederror"et un"max_depth"de2. - Entraînez le modèle avec
10itérations de boosting et le dictionnaire de paramètres que vous avez créé. Enregistrez le résultat dansxg_reg. - Tracez le premier arbre avec
xgb.plot_tree(). Cette fonction prend deux arguments : le modèle (icixg_reg) etnum_trees, qui est indexé à partir de 0. Pour tracer le premier arbre, indiquez doncnum_trees=0. - Tracez le cinquième arbre.
- Tracez le dernier (dixième) arbre en mode paysage. Pour cela, indiquez l’argument mot-clé supplémentaire
rankdir="LR".
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Create the DMatrix: housing_dmatrix
housing_dmatrix = xgb.DMatrix(data=X, label=y)
# Create the parameter dictionary: params
params = {"objective":"reg:squarederror", "max_depth":2}
# Train the model: xg_reg
xg_reg = xgb.train(params=params, dtrain=housing_dmatrix, num_boost_round=10)
# Plot the first tree
____
plt.show()
# Plot the fifth tree
____
plt.show()
# Plot the last tree sideways
____
plt.show()