
Linear base learners

Now that you've used trees as base models in XGBoost, let's use the other kind of base model XGBoost supports: a linear learner. Although less commonly used than trees, it lets you fit a regularized linear regression through XGBoost's learning API. In this exercise you'll build the model with XGBoost's own (non-scikit-learn) functions, such as xgb.train().

To do this, you must create a parameter dictionary that describes the kind of booster you want to use (similar to how you created the dictionary in Chapter 1 when you used xgb.cv()). The key-value pair that defines the booster type (base model) you need is "booster":"gblinear".
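As a minimal sketch, such a parameter dictionary might look like the following. The "booster" and "objective" entries come from this exercise. The "lambda" and "alpha" entries are XGBoost's L2 and L1 regularization weights for the linear booster, and the values shown are illustrative assumptions rather than settings from the course.

```python
# Parameter dictionary for a linear base learner.
params = {
    "booster": "gblinear",            # use linear models instead of trees
    "objective": "reg:squarederror",  # minimize squared error (regression)
    "lambda": 1.0,                    # L2 penalty on weights (illustrative value)
    "alpha": 0.0,                     # L1 penalty on weights (illustrative value)
}

print(params["booster"])
```

Raising "lambda" or "alpha" strengthens the penalty on the weights, which is what makes the resulting fit a regularized linear regression.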

Once you've created the parameter dictionary, you can train the model with xgb.train() and call the .predict() method of the resulting model, just like you've done in the past.

Here, the data has already been split into training and testing sets, so you can dive right into creating the DMatrix objects required by the XGBoost learning API.

This exercise is part of the course

Extreme Gradient Boosting with XGBoost

Exercise instructions

  • Create two DMatrix objects: DM_train for the training set (X_train and y_train), and DM_test for the test set (X_test and y_test).
  • Create a parameter dictionary that defines the "booster" type you will use ("gblinear") as well as the "objective" you will minimize ("reg:squarederror").
  • Train the model using xgb.train(). You have to specify arguments for the following parameters: params, dtrain, and num_boost_round. Use 5 boosting rounds.
  • Predict the labels on the test set using xg_reg.predict(), passing it DM_test. Assign to preds.
  • Hit 'Submit Answer' to view the RMSE!

Hands-on interactive exercise

Have a go at this exercise by completing this sample code.

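# Imports the exercise relies on (preloaded in the course environment)
import numpy as np
import xgboost as xgb
from sklearn.metrics import mean_squared_error

# Convert the training and testing sets into DMatrix objects: DM_train, DM_test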
DM_train = ____
DM_test =  ____

# Create the parameter dictionary: params
params = {"____":"____", "____":"____"}

# Train the model: xg_reg
xg_reg = ____.____(____ = ____, ____=____, ____=____)

# Predict the labels of the test set: preds
preds = ____

# Compute and print the RMSE
rmse = np.sqrt(mean_squared_error(y_test, preds))
print("RMSE: %f" % rmse)