1. Decision-Tree for Regression
Welcome back! In this video, you'll learn how to train a decision tree for a regression problem.
Recall that in regression, the target variable is continuous. In other words, the output of your model is a real value.
2. Auto-mpg Dataset
Let's motivate our discussion of regression by introducing the automobile miles-per-gallon dataset from the UCI Machine Learning Repository.
This dataset consists of six features describing the characteristics of a car, and a continuous target variable labeled mpg, which stands for miles-per-gallon. Our task is to predict a car's mpg (its fuel efficiency) given these six features.
To simplify the problem, the analysis here is restricted to a single feature: the car's engine displacement, denoted displ.
3. Auto-mpg with one feature
A 2D scatter plot of mpg versus displ shows that mpg decreases nonlinearly as displacement increases. A linear model such as linear regression, fit on displ alone, cannot capture such a nonlinear trend.
Let's see how you can train a decision tree with scikit-learn to solve this regression problem.
4. Regression-Tree in scikit-learn
Note that the features X and the labels y are already loaded in the environment.
First, import DecisionTreeRegressor from sklearn-dot-tree, the function train_test_split() from sklearn-dot-model_selection, and mean_squared_error, imported as MSE(), from sklearn-dot-metrics.
Then, use train_test_split() to split the data into an 80% training set and a 20% test set.
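The import and split steps can be sketched as follows. In the course, X and y are preloaded from the Auto-mpg data. This sketch is self-contained, so it builds a hypothetical synthetic stand-in instead: displ drawn uniformly between 70 and 455, and mpg = 10 + 3000/displ plus noise. The random_state value is also an arbitrary choice made here.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor  # used in the next steps
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error as MSE  # used in the next steps

# Hypothetical stand-in for the preloaded X and y (NOT the Auto-mpg data):
# one feature (displacement) and a noisy, nonlinearly decreasing mpg target.
rng = np.random.default_rng(3)
displ = rng.uniform(70, 455, size=400)
mpg = 10 + 3000 / displ + rng.normal(0, 2, size=400)
X = displ.reshape(-1, 1)  # scikit-learn expects a 2D feature matrix
y = mpg

# Hold out 20% of the rows as a test set; random_state makes the split reproducible
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=3
)
print(X_train.shape, X_test.shape)  # (320, 1) (80, 1)
```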
You can now instantiate a DecisionTreeRegressor() called dt with a maximum depth of 4 by setting the parameter max_depth to 4. In addition, set the parameter min_samples_leaf to 0-dot-1 to impose a stopping condition in which each leaf must contain at least 10% of the training data.
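To check what these two stopping conditions actually do, the sketch below fits dt and inspects the fitted tree. It uses the same hypothetical synthetic data as before and an assumed random_state of 3. When min_samples_leaf is a float, scikit-learn treats it as a fraction and requires ceil(min_samples_leaf * n_samples) training samples per leaf. The tree_ attributes used here (children_left, n_node_samples) are scikit-learn's low-level tree arrays.

```python
import math
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split

# Hypothetical synthetic data standing in for the preloaded X and y
rng = np.random.default_rng(3)
X = rng.uniform(70, 455, size=(400, 1))
y = 10 + 3000 / X[:, 0] + rng.normal(0, 2, size=400)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=3
)

# max_depth caps the tree at 4 levels; the float min_samples_leaf means
# every leaf must hold at least ceil(0.1 * n_train) training samples
dt = DecisionTreeRegressor(max_depth=4, min_samples_leaf=0.1, random_state=3)
dt.fit(X_train, y_train)

tree = dt.tree_
is_leaf = tree.children_left == -1         # leaves have no children
leaf_sizes = tree.n_node_samples[is_leaf]  # training samples in each leaf
min_required = math.ceil(0.1 * len(X_train))
print(dt.get_depth(), leaf_sizes.min(), min_required)
```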
5. Regression-Tree in scikit-learn
Now fit dt to the training set and predict the test set labels.
To obtain the root-mean-squared error (RMSE) of your model on the test set, proceed as follows:
- first, evaluate the mean-squared error,
- then, raise the obtained value to the power 1/2.
Finally, print dt's test-set RMSE; you should obtain a value of 5-dot-1.
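Putting the fit, predict, and RMSE steps together gives something like the sketch below. It runs on hypothetical synthetic data rather than the Auto-mpg data, so the printed RMSE will not match the 5-dot-1 quoted above. The random_state values are assumptions made for reproducibility.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error as MSE

# Hypothetical synthetic data (NOT the Auto-mpg dataset)
rng = np.random.default_rng(3)
X = rng.uniform(70, 455, size=(400, 1))
y = 10 + 3000 / X[:, 0] + rng.normal(0, 2, size=400)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=3
)

dt = DecisionTreeRegressor(max_depth=4, min_samples_leaf=0.1, random_state=3)
dt.fit(X_train, y_train)     # learn the splits from the training set
y_pred = dt.predict(X_test)  # predict the held-out labels

mse_dt = MSE(y_test, y_pred)  # step 1: mean squared error on the test set
rmse_dt = mse_dt ** (1 / 2)   # step 2: raise to the power 1/2 to get the RMSE
print("Test set RMSE of dt: {:.2f}".format(rmse_dt))
```

Raising the MSE to the power 1/2 puts the error back in the target's own units (miles-per-gallon, in the Auto-mpg case). That makes the RMSE easier to interpret than the raw MSE.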
6. Information Criterion for Regression-Tree
Here, it's important to note that when a regression tree is trained on a dataset, the impurity of a node is measured using the mean-squared error of the targets in that node.
This means that the regression tree tries to find the splits that produce leaves in which the target values are, on average, as close as possible to the mean value of the labels in that particular leaf.
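To make the criterion concrete, the sketch below defines a node's impurity as I(node) = (1/N_node) Σ (y_i − ȳ_node)², the mean squared deviation from the node's mean. It then brute-forces the single threshold on one feature that minimizes the sample-weighted impurity of the two children. The helper names node_impurity and best_split are hypothetical, and this is a simplified illustration rather than scikit-learn's implementation. A depth-1 scikit-learn tree is fit on the same toy data for comparison.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def node_impurity(y):
    """MSE impurity of a node: mean squared deviation from the node's mean."""
    return np.mean((y - y.mean()) ** 2)

def best_split(x, y):
    """Try every threshold on one feature; return the lowest sample-weighted
    child impurity and its threshold (hypothetical, simplified helper)."""
    order = np.argsort(x)
    x, y = x[order], y[order]
    best_score, best_threshold = np.inf, None
    for i in range(1, len(x)):
        if x[i] == x[i - 1]:
            continue  # no threshold can separate equal feature values
        left, right = y[:i], y[i:]
        score = (len(left) * node_impurity(left)
                 + len(right) * node_impurity(right)) / len(y)
        if score < best_score:
            best_score, best_threshold = score, (x[i - 1] + x[i]) / 2  # midpoint
    return best_score, best_threshold

# Toy data with two clear groups of targets
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = np.array([10.0, 11.0, 10.0, 30.0, 31.0, 30.0])
score, threshold = best_split(x, y)
print(threshold)  # 3.5

stump = DecisionTreeRegressor(max_depth=1).fit(x.reshape(-1, 1), y)
print(stump.tree_.threshold[0])  # 3.5
print(np.isclose(stump.tree_.impurity[0], node_impurity(y)))  # True
```

On this toy data, the split between 3 and 4 leaves each child with targets that sit close to that child's mean. That split is therefore the one with the lowest weighted impurity.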
7. Prediction
As a new instance traverses the tree and reaches a certain leaf, its predicted target value is computed as the average of the training targets contained in that leaf: ŷ_pred(leaf) = (1/N_leaf) · Σ_{i ∈ leaf} y_i, where N_leaf is the number of training samples in the leaf.
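This prediction rule can be checked directly with scikit-learn's apply() method, which returns the index of the leaf each sample lands in. The sketch recomputes each prediction by hand as the mean of the training targets in the same leaf. The training data and the three "new cars" are hypothetical.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical synthetic training data and new instances
rng = np.random.default_rng(0)
X_train = rng.uniform(70, 455, size=(200, 1))
y_train = 10 + 3000 / X_train[:, 0] + rng.normal(0, 2, size=200)
X_new = np.array([[100.0], [250.0], [400.0]])

dt = DecisionTreeRegressor(max_depth=4, min_samples_leaf=0.1, random_state=0)
dt.fit(X_train, y_train)

train_leaves = dt.apply(X_train)  # leaf index reached by each training row
new_leaves = dt.apply(X_new)      # leaf index reached by each new instance

# Manual prediction: average of the training targets in the same leaf
manual = np.array([y_train[train_leaves == leaf].mean() for leaf in new_leaves])
print(np.allclose(manual, dt.predict(X_new)))  # True
```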
8. Linear Regression vs. Regression-Tree
To highlight how flexible regression trees are, take a look at this figure.
On the left, we have a scatter plot of the data in blue along with the predictions of a linear regression model shown in black. The linear model fails to capture the nonlinear trend exhibited by the data.
On the right, we have the same scatter plot along with a red line corresponding to the predictions of the regression tree that you trained earlier. The regression tree is more flexible and captures the nonlinearity, though not fully.
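A rough numerical version of this comparison is sketched below. A linear regression and a depth-4 regression tree are fit to hypothetical data with a curved trend (y ≈ 50/x plus noise), and their in-sample RMSEs are compared. The data, the noise level, and the choice to omit min_samples_leaf here are all assumptions for illustration, not the figure's actual setup.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error as MSE

# Hypothetical data with a curved, decreasing trend (not the Auto-mpg data)
rng = np.random.default_rng(0)
X = np.linspace(1, 10, 200).reshape(-1, 1)
y = 50 / X[:, 0] + rng.normal(0, 1, size=200)

lr = LinearRegression().fit(X, y)  # a single straight line
dt = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X, y)  # piecewise-constant

# In-sample RMSE: how closely each model's prediction curve follows the data
rmse_lr = MSE(y, lr.predict(X)) ** 0.5
rmse_dt = MSE(y, dt.predict(X)) ** 0.5
print(f"linear: {rmse_lr:.2f}  tree: {rmse_dt:.2f}")
```

The tree's prediction curve is a step function, so it can bend with the data where a straight line cannot. The steps are also why the fit is only approximate.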
In the next chapter, you'll aggregate the predictions of a set of trees that are trained differently to obtain better results.
9. Let's practice!
Now it's your turn to practice.