Gradient boosted trees: previsão
Depois de executar seu modelo, o próximo passo é fazer previsões com ele. Diferentemente do base R, que usa a função predict() para prever, o sparklyr usa a função ml_predict(). ml_predict() recebe dois argumentos: um modelo e alguns dados de teste.
ml_predict(a_model, testing_data)
Um caso de uso comum é comparar as respostas previstas com as respostas reais, o que você pode visualizar com gráficos no R. O padrão de código para preparar esses dados é o seguinte. Observe que, no momento, adicionar uma coluna de predição precisa ser feito localmente, então você deve coletar os resultados primeiro.
predicted_vs_actual <- testing_data %>%
select(actual) %>%
collect() %>%
mutate(predicted)
Este exercício faz parte do curso
Introdução ao Spark com sparklyr em R
Instruções do exercício
Uma conexão com o Spark foi criada para você como spark_conn. Tibbles anexados aos conjuntos de dados de treino e teste armazenados no Spark já foram definidos como track_data_to_model_tbl e track_data_to_predict_tbl, respectivamente. O modelo de gradient boosted trees já foi definido como gradient_boosted_trees_model.
- Defina uma variável
predictedque contenha as previsões do modelo para nossos dados de teste.- Chame
ml_predict()com o modelo e os dados de teste como argumentos. Essa função vai gerar previsões para o conjunto de teste e adicioná-las como uma nova coluna chamadaprediction. - Usando
pull(), podemos extrair essa coluna e atribuí-la apredicted.
- Chame
- Defina a variável
responsespara preparar os dados para comparar as respostas previstas com as reais:- Selecione a coluna de resposta
year. - Colete os resultados.
- Use
mutate()para adicionar as previsões feitas empredicted.
- Selecione a coluna de resposta
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
# Training, testing sets & model are pre-defined
track_data_to_model_tbl
track_data_to_predict_tbl
gradient_boosted_trees_model
# Predict the responses for the testing data
predicted <- ___(
___,
___) %>% pull(prediction)
# Prepare the data for comparing predicted responses with actual responses
responses <- track_data_to_predict_tbl %>%
# Select the response column
___ %>%
# Collect the results
___ %>%
# Add in the predictions
mutate(___)