
Gradient boosted trees: prediction

Once you've fitted your model, the next step is to make predictions with it. Whereas base R uses the predict() function to make predictions, sparklyr uses the ml_predict() function. ml_predict() takes two arguments: a model and some testing data.

ml_predict(a_model, testing_data)
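The following is a minimal end-to-end sketch of ml_predict(). It assumes a local Spark installation (for example, one set up with spark_install()) and uses the built-in mtcars data, with mpg as the response, as a stand-in for the track data. The table name "cars", the formula, the split proportions, and the seed are illustrative choices, not part of the exercise.

```r
library(sparklyr)
library(dplyr)

# Assumption: Spark is installed locally, e.g. via spark_install().
sc <- spark_connect(master = "local")

# mtcars stands in for the track data; mpg plays the role of the response.
cars_tbl <- copy_to(sc, mtcars, "cars", overwrite = TRUE)

# Split the Spark table into training and testing sets.
splits <- sdf_random_split(cars_tbl, training = 0.7, testing = 0.3, seed = 42)

# Fit a gradient boosted trees regression model on the training set.
gbt_model <- ml_gradient_boosted_trees(
  splits$training,
  mpg ~ wt + hp + cyl,
  type = "regression"
)

# ml_predict() returns the testing data with an extra `prediction` column.
scored_tbl <- ml_predict(gbt_model, splits$testing)
scored_cols <- colnames(scored_tbl)

# pull() brings the prediction column into R as an ordinary numeric vector.
predicted <- scored_tbl %>% pull(prediction)

spark_disconnect(sc)
```

Note that the predictions live in Spark until pull() collects them; after that, `predicted` is a plain local vector.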

A common use case is to compare the predicted responses with the actual responses, for example by plotting one against the other in R. The code pattern for preparing this data is as follows. Note that adding the predictions as a column currently has to be done locally: a local R vector of predictions can't be added with mutate() to a tibble stored in Spark, so you must collect the results first.

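predicted_vs_actual <- testing_data %>%
  select(actual) %>%   # keep only the actual response column
  collect() %>%        # copy the results from Spark into R
  mutate(predicted)    # add the local vector of predictions as a column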
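The collect-then-mutate step can be tried without Spark at all, because after collect() the responses are just a local tibble. In this sketch the `year` values and the `predicted` vector are made-up numbers standing in for collected responses and pulled predictions; the column names `predicted_year` and `residual` are likewise illustrative. It also computes a root-mean-square error and builds a predicted-versus-actual plot with ggplot2.

```r
library(dplyr)
library(ggplot2)

# Made-up stand-in for `testing_data %>% select(year) %>% collect()`.
collected <- tibble(year = c(1990, 2000, 2010))

# Made-up stand-in for predictions pulled from ml_predict() output.
predicted <- c(1992, 1998, 2011)

# Both objects are local, so mutate() can attach the vector as a column.
predicted_vs_actual <- collected %>%
  mutate(
    predicted_year = predicted,
    residual = predicted_year - year
  )

# Root-mean-square error as a simple summary of prediction quality.
rmse <- sqrt(mean(predicted_vs_actual$residual^2))

# Points on the dashed line would be perfect predictions.
p <- ggplot(predicted_vs_actual, aes(year, predicted_year)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed")
```

This pattern relies on the collected responses and the pulled predictions coming back in the same row order. Because ml_predict() keeps the original columns, one way to avoid that assumption is to select both columns from its output, for example `ml_predict(model, data) %>% select(year, prediction) %>% collect()`.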

This exercise is part of the course

Introduction to Spark with sparklyr in R


Exercise instructions

A Spark connection has been created for you as spark_conn. Tibbles attached to the training and testing datasets stored in Spark have been pre-defined as track_data_to_model_tbl and track_data_to_predict_tbl respectively. The gradient boosted trees model has been pre-defined as gradient_boosted_trees_model.

  • Define a variable predicted that contains the model's predictions for our testing data.
    • Call ml_predict() with the model and the testing data as arguments. This function will generate predictions for the testing dataset and add these as a new column named prediction.
    • Using pull(), we can extract this column and assign it to predicted.
  • Define the responses variable to prepare the data for comparing predicted responses with actual responses:
    • Select the response column year.
    • Collect the results.
    • Use mutate() to add the values in predicted as a new column.

Hands-on interactive exercise

Have a go at this exercise by completing this sample code.

# Training, testing sets & model are pre-defined
track_data_to_model_tbl
track_data_to_predict_tbl
gradient_boosted_trees_model

# Predict the responses for the testing data
predicted <- ___(
      ___,
      ___) %>% pull(prediction)

# Prepare the data for comparing predicted responses with actual responses
responses <- track_data_to_predict_tbl %>%
  # Select the response column
  ___ %>%
  # Collect the results
  ___ %>%
  # Add in the predictions
  mutate(___)