Prever com o modelo de soja nos dados de teste
Neste exercício, você vai aplicar os modelos de soja do exercício anterior (model.lin e model.gam, já carregados) a novos dados: soybean_test.
Este exercício faz parte do curso
Aprendizado Supervisionado em R: Regressão
Instruções do exercício
- Crie uma coluna
soybean_test$pred.lincom previsões do modelo linearmodel.lin. - Crie uma coluna
soybean_test$pred.gamcom previsões do modelo GAMmodel.gam.- Para modelos GAM, o método
predict()retorna uma matriz, então useas.numeric()para converter a matriz em vetor.
- Para modelos GAM, o método
- Preencha as lacunas para usar
pivot_longer()e transformar as colunas de previsão em uma única coluna de valorespred, com a coluna de chavemodeltype. Chame o data frame no formato longo desoybean_long. - Calcule e compare o RMSE de ambos os modelos.
- Qual modelo é melhor?
- Execute o código para comparar as previsões de cada modelo com os pesos médios reais das folhas.
- Um gráfico de dispersão de
weightem função deTime. - Gráficos de pontos e linhas das previsões (
pred) em função deTime. - Observe que o modelo linear às vezes prevê pesos negativos! O modelo GAM faz isso?
- Um gráfico de dispersão de
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
# soybean_test is available
summary(soybean_test)
# Get predictions from linear model
soybean_test$pred.lin <- ___(___, newdata = ___)
# Get predictions from gam model
soybean_test$pred.gam <- ___(___(___, newdata = ___))
# Pivot the predictions into a "long" dataset
soybean_long <- soybean_test %>%
pivot_longer(cols = c(___, ___), names_to = ___, values_to = ___)
# Calculate the rmse
soybean_long %>%
mutate(residual = weight - pred) %>% # residuals
group_by(modeltype) %>% # group by modeltype
summarize(rmse = ___(___(___))) # calculate the RMSE
# Compare the predictions against actual weights on the test data
soybean_long %>%
ggplot(aes(x = Time)) + # the column for the x axis
geom_point(aes(y = weight)) + # the y-column for the scatterplot
geom_point(aes(y = pred, color = modeltype)) + # the y-column for the point-and-line plot
geom_line(aes(y = pred, color = modeltype, linetype = modeltype)) + # the y-column for the point-and-line plot
scale_color_brewer(palette = "Dark2")