Visualizing individual XGBoost trees

Now that you've used XGBoost to build and evaluate both regression and classification models, it's time to explore your models visually. Here, you will visualize individual trees from the fully boosted model that XGBoost creates using the entire housing dataset.

XGBoost has a plot_tree() function that makes this type of visualization easy. Once you train a model using the XGBoost learning API, you can pass it to the plot_tree() function along with the index of the tree you want to plot, given by the num_trees argument. Despite its name, num_trees selects a single tree by position rather than setting how many trees are drawn.

This exercise is part of the course “Extreme Gradient Boosting with XGBoost”.

Exercise instructions

  • Create a parameter dictionary with an "objective" of "reg:squarederror" and a "max_depth" of 2.
  • Train the model using 10 boosting rounds and the parameter dictionary you created. Save the result in xg_reg.
  • Plot the first tree using xgb.plot_tree(). Pass it two arguments: the model (in this case, xg_reg) and num_trees, the 0-based index of the tree to plot. So to plot the first tree, specify num_trees=0.
  • Plot the fifth tree.
  • Plot the last (tenth) tree sideways. To do this, specify the additional keyword argument rankdir="LR".
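Because the instructions above count trees from one while num_trees counts from zero, it is easy to be off by one. The sketch below spells out the conversion with a hypothetical helper, position_to_index, which is not part of XGBoost.

```python
def position_to_index(position, num_boost_round):
    """Convert a 1-based tree position into the 0-based index used by num_trees."""
    if not 1 <= position <= num_boost_round:
        raise ValueError(
            f"position must be between 1 and {num_boost_round}, got {position}"
        )
    return position - 1


rounds = 10  # matches the 10 boosting rounds used in this exercise

first_idx = position_to_index(1, rounds)
fifth_idx = position_to_index(5, rounds)
last_idx = position_to_index(rounds, rounds)

print(first_idx, fifth_idx, last_idx)
```

The range check is a design choice for this sketch: it reports an out-of-range position up front instead of leaving the error to the plotting call.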

Hands-on interactive exercise

Have a go at this exercise by completing this sample code.

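# Imports used by this sample (X and y are assumed to be preloaded
# with the housing features and target)
import matplotlib.pyplot as plt
import xgboost as xgb

# Create the DMatrix: housing_dmatrix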
housing_dmatrix = xgb.DMatrix(data=X, label=y)

# Create the parameter dictionary: params
params = {"objective":"reg:squarederror", "max_depth":2}

# Train the model: xg_reg
xg_reg = xgb.train(params=params, dtrain=housing_dmatrix, num_boost_round=10)

# Plot the first tree
____
plt.show()

# Plot the fifth tree
____
plt.show()

# Plot the last tree sideways
____
plt.show()