Linear base learners
Now that you've used trees as base models in XGBoost, let's use the other kind of base model that can be used with XGBoost: a linear learner. This model, although not as commonly used in XGBoost, allows you to create a regularized linear regression using XGBoost's powerful learning API. However, because it's uncommon, you have to use XGBoost's own non-scikit-learn compatible functions to build the model, such as xgb.train().
In order to do this, you must create the parameter dictionary that describes the kind of booster you want to use (similar to how you created the dictionary in Chapter 1 when you used xgb.cv()). The key-value pair that defines the booster type (base model) you need is "booster":"gblinear".
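As a small illustration of the parameter dictionary described above, the sketch below builds one for a linear base learner. The "objective" entry matches the one named in the exercise instructions. The "lambda" (L2) and "alpha" (L1) entries are not part of this exercise; they are shown only as an assumption about how you might control the regularization, with arbitrary illustrative values.

```python
# Minimal parameter dictionary for a linear base learner.
params = {
    "booster": "gblinear",            # use a linear model instead of trees
    "objective": "reg:squarederror",  # squared-error loss for regression
}

# Optional, not required by the exercise: explicit regularization strengths.
# The values below are arbitrary and chosen only for illustration.
regularized_params = {**params, "lambda": 1.0, "alpha": 0.1}

print(params)
print(regularized_params)
```

Either dictionary can be passed as the params argument of xgb.train().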
Once you've created the parameter dictionary, you can train the model with xgb.train() and then call the .predict() method of the trained model, just like you've done in the past.
Here, the data has already been split into training and testing sets, so you can dive right into creating the DMatrix objects required by the XGBoost learning API.
This exercise is part of the course Extreme Gradient Boosting with XGBoost.
Exercise instructions
- Create two DMatrix objects: DM_train for the training set (X_train and y_train), and DM_test for the test set (X_test and y_test).
- Create a parameter dictionary that defines the "booster" type you will use ("gblinear") as well as the "objective" you will minimize ("reg:squarederror").
- Train the model using xgb.train(). You have to specify arguments for the following parameters: params, dtrain, and num_boost_round. Use 5 boosting rounds.
- Predict the labels on the test set using xg_reg.predict(), passing it DM_test. Assign to preds.
- Hit 'Submit Answer' to view the RMSE!
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
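# Imports used by this sample code
import numpy as np
import xgboost as xgb
from sklearn.metrics import mean_squared_error
# Convert the training and testing sets into DMatrixes: DM_train, DM_test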
DM_train = ____
DM_test = ____
# Create the parameter dictionary: params
params = {"____":"____", "____":"____"}
# Train the model: xg_reg
xg_reg = ____.____(____ = ____, ____=____, ____=____)
# Predict the labels of the test set: preds
preds = ____
# Compute and print the RMSE
rmse = np.sqrt(mean_squared_error(y_test,preds))
print("RMSE: %f" % (rmse))