
Cross-validation

1. Cross-validation

Welcome back! So far, we have always used a single train/test split. However, this is a little fragile: a single outlier can vastly change our estimate of the out-of-sample error. One way to reduce this variance is to average multiple estimates together. This is exactly what cross-validation does.

2. k-fold cross-validation

How exactly does it work? The first step is to partition the rows of the training dataset into k subsets of equal size. So, if you have 500 data points and 5 folds, there will be 100 data points in each fold.

3. k-fold cross-validation

In each iteration, you pick one of the k subsets as your test set

4. k-fold cross-validation

and the remaining k minus 1 subsets are used as the aggregated training set.

5. k-fold cross-validation

6. k-fold cross-validation

You train your machine learning model on each training set and evaluate the model's performance on each test set.

7. k-fold cross-validation

Once you are finished, you will have 5 estimates of the out-of-sample error.

8. k-fold cross-validation

We average those 5 estimates together to get what's called the cross-validated estimate of the error. Since we end up training k models instead of one, cross-validation takes roughly k times as long to evaluate your models this way.
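To make the mechanics concrete, here is a minimal hand-rolled sketch of 5-fold cross-validation in plain R. This is not the tidymodels workflow used in this course; the built-in mtcars data, the lm() model, and the formula are stand-in assumptions purely for illustration.

```r
# Minimal hand-rolled 5-fold cross-validation (illustration only).
set.seed(42)
k <- 5
fold_id <- sample(rep(1:k, length.out = nrow(mtcars)))  # assign each row to a fold

errors <- sapply(1:k, function(i) {
  train <- mtcars[fold_id != i, ]   # k - 1 folds form the training set
  test  <- mtcars[fold_id == i, ]   # the remaining fold is the test set
  model <- lm(mpg ~ wt + hp, data = train)
  preds <- predict(model, newdata = test)
  mean(abs(test$mpg - preds))       # out-of-sample MAE for this fold
})

errors        # k estimates of the out-of-sample error
mean(errors)  # the cross-validated estimate of the error
```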

9. Fit final model on the full dataset

One important note: you use cross-validation to estimate the out-of-sample error of your model. Once you are finished, you throw all the resampled models away and start over, fitting your model on the full training dataset to fully exploit the information in that dataset.
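As a sketch of that final step, assuming the tree specification tree_spec and the training data chocolate_train used later in this lesson (the model formula here is a placeholder):

```r
# Fit the final model on the full training data (formula is a placeholder).
final_model <- fit(tree_spec, final_price ~ ., data = chocolate_train)
```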

10. Coding - Split the data 10 times

Using the tidymodels package, it's incredibly easy to perform cross-validation. The relevant function is called vfold_cv(). It expects the data to split and v, the number of folds to create. Let's create ten folds of chocolate_train. The result is a tibble with ten rows, each row containing one split together with an id for that fold.
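A short sketch of that call, assuming chocolate_train from earlier in the course (the set.seed() call is added here only for reproducibility):

```r
library(tidymodels)

# Split the training data into 10 folds.
set.seed(20)
chocolate_folds <- vfold_cv(chocolate_train, v = 10)

chocolate_folds
# A tibble with 10 rows: a splits column and an id column (Fold01 ... Fold10)
```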

11. Coding - Fit the folds

Next, we want to train a model on every fold and measure the out-of-sample performance of each model. tidymodels gives us a shortcut called fit_resamples(). It takes the tree specification, tree_spec; the model formula; the resamples, or folds, chocolate_folds; and the metrics you want to assess. The metrics are bundled together using the metric_set() function. You already know MAE and RMSE. The result is the previous tibble with a new column, .metrics, containing the out-of-sample errors for every fold. Every result in the .metrics column has two rows because we asked for MAE and RMSE in the metric_set.
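A sketch of that call under the same assumptions (tree_spec and chocolate_folds as in the course; the model formula is a placeholder):

```r
# Fit and evaluate the tree on every fold.
fit_results <- fit_resamples(
  tree_spec,
  final_price ~ .,                   # placeholder formula
  resamples = chocolate_folds,
  metrics   = metric_set(mae, rmse)
)

fit_results
# The folds tibble gains a .metrics column: one tibble per fold with two rows (mae, rmse)
```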

12. Coding - Collect all errors

There is a handy function that extracts all errors from the fitting results: collect_metrics(). It takes your cross-validation results and an argument, summarize, which specifies whether you want summary statistics. With summarize set to FALSE, the result is a tibble containing every error of every fold. Let's draw a histogram using ggplot2 to visualize these errors. We use .estimate as the x aesthetic and .metric as the fill variable. Without cross-validation, we would have only one of the red bars, that is, one mean absolute error, and only one of the blue bars, one root mean squared error. See how useful cross-validation is for estimating model performance?
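A sketch of the extraction and the histogram, assuming fit_results from the previous step:

```r
# Keep every fold-level error rather than summarizing across folds.
all_errors <- collect_metrics(fit_results, summarize = FALSE)

# One fill colour per metric (MAE and RMSE), one observation per fold.
ggplot(all_errors, aes(x = .estimate, fill = .metric)) +
  geom_histogram()
```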

13. Coding - Summarize training sessions

Of course, you can specify summarize equals TRUE, which is the default in collect_metrics(). This results in a small tibble showing the name of each metric, its mean, and n, the number of errors that were averaged. Here, n equals 10 because we had 10 folds in our cross-validation, and the mean absolute out-of-sample error in this example is 0.383.
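And the summarized version, again assuming fit_results from before:

```r
# summarize = TRUE is the default, so this averages the errors across folds.
collect_metrics(fit_results)
# Returns .metric, .estimator, mean, n (here 10), and std_err for mae and rmse
```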

14. Let's cross-validate!

Let's practice cross-validating your model.