Visualizing individual XGBoost trees
Now that you've used XGBoost to both build and evaluate regression as well as classification models, you should get a handle on how to visually explore your models. Here, you will visualize individual trees from the fully boosted model that XGBoost creates using the entire housing dataset.
XGBoost has a plot_tree() function that makes this type of visualization easy. Once you train a model using the XGBoost learning API, you can pass it to the plot_tree() function along with the index of the tree you want to plot, using the num_trees argument.
This is a part of the course “Extreme Gradient Boosting with XGBoost”
Exercise instructions
- Create a parameter dictionary with an "objective" of "reg:squarederror" and a "max_depth" of 2.
- Train the model using 10 boosting rounds and the parameter dictionary you created. Save the result in xg_reg.
- Plot the first tree using xgb.plot_tree(). It takes in two arguments - the model (in this case, xg_reg), and num_trees, which is 0-indexed. So to plot the first tree, specify num_trees=0.
- Plot the fifth tree.
- Plot the last (tenth) tree sideways. To do this, specify the additional keyword argument rankdir="LR".
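Because num_trees is 0-indexed, the "nth" tree always sits at index n - 1. The following is a minimal sketch of that mapping. The helper plot_tree_kwargs() is hypothetical (it is not part of XGBoost); it only builds the keyword arguments described in the instructions above, and the argument names should be checked against the documentation for your installed XGBoost version.

```python
def plot_tree_kwargs(position, n_rounds, sideways=False):
    """Build keyword arguments for xgb.plot_tree() from a 1-based tree position.

    Hypothetical helper for illustration; not part of the XGBoost API.
    """
    if not 1 <= position <= n_rounds:
        raise ValueError(f"position must be between 1 and {n_rounds}")
    kwargs = {"num_trees": position - 1}  # num_trees is 0-indexed
    if sideways:
        kwargs["rankdir"] = "LR"  # left-to-right layout
    return kwargs


first = plot_tree_kwargs(1, n_rounds=10)
fifth = plot_tree_kwargs(5, n_rounds=10)
last_sideways = plot_tree_kwargs(10, n_rounds=10, sideways=True)
print(first, fifth, last_sideways)
```

With a trained model, a call such as xgb.plot_tree(xg_reg, **last_sideways) followed by plt.show() would then draw the chosen tree.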
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
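# Imports used below (X and y are assumed to be preloaded by the exercise environment)
import xgboost as xgb
import matplotlib.pyplot as plt
# Create the DMatrix: housing_dmatrix
housing_dmatrix = xgb.DMatrix(data=X, label=y)
# Create the parameter dictionary: params
params = {"objective": "reg:squarederror", "max_depth": 2}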
# Train the model: xg_reg
xg_reg = xgb.train(params=params, dtrain=housing_dmatrix, num_boost_round=10)
# Plot the first tree
____
plt.show()
# Plot the fifth tree
____
plt.show()
# Plot the last tree sideways
____
plt.show()