Ajuste um modelo xgboost de aluguel de bicicletas e faça previsões
Neste exercício, você vai ajustar um modelo de gradient boosting usando xgboost() para prever o número de bicicletas alugadas em uma hora em função do clima, do tipo e do horário do dia. Você vai treinar o modelo com dados do mês de julho e prever com dados do mês de agosto.
Os data frames bikesJuly, bikesJuly.treat, bikesAugust e bikesAugust.treat já foram pré-carregados. Lembre-se de que os dados tratados com vtreat não têm mais a coluna de desfecho, então você deve obtê-la dos dados originais (a coluna cnt).
Para sua conveniência, o número de árvores a usar, ntrees do exercício anterior, está disponível.
Os argumentos de xgboost() (docs) são semelhantes aos de xgb.cv().
Este exercício faz parte do curso
Aprendizado Supervisionado em R: Regressão
Instruções do exercício
- Preencha os espaços para executar
xgboost()nos dados de julho.- Use
as.matrix()para converter o data frame tratado pelo vtreat em matriz. - O objective deve ser
"reg:squarederror". - Use
ntreesrounds. - Defina
etacomo0.75,max_depthcomo5everbosecomoFALSE(silencioso).
- Use
- Agora chame
predict()embikesAugust.treatpara prever o número de bicicletas alugadas em agosto.- Use
as.matrix()para converter os dados de teste tratados comvtreatem matriz. - Adicione as previsões a
bikesAugustcomo a colunapred.
- Use
- Preencha os espaços para plotar as contagens reais de aluguel de bicicletas versus as previsões (previsões no eixo x).
- Você percebe um possível problema nas previsões?
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
# Run xgboost
bike_model_xgb <- xgboost(data = ___, # training data as matrix
label = ___, # column of outcomes
nrounds = ___, # number of trees to build
objective = ___, # objective
eta = ___,
max_depth = ___,
verbose = FALSE # silent
)
# Make predictions
bikesAugust$pred <- ___(___, ___(___))
# Plot predictions (on x axis) vs actual bike rental count
ggplot(bikesAugust, aes(x = ___, y = ___)) +
geom_point() +
geom_abline()