Predict with the soybean model on test data
In this exercise, you will apply the soybean models from the previous exercise (model.lin and model.gam, already loaded) to new data: soybean_test.
This exercise is part of the course
Supervised Learning in R: Regression
Exercise instructions
- Create a column soybean_test$pred.lin with predictions from the linear model model.lin.
- Create a column soybean_test$pred.gam with predictions from the gam model model.gam.
  - For GAM models, the predict() method returns a matrix, so use as.numeric() to convert the matrix to a vector.
- Fill in the blanks to pivot_longer() the prediction columns into a single value column pred with key column modeltype. Call the long data frame soybean_long.
- Calculate and compare the RMSE of both models.
  - Which model does better?
- Run the code to compare the predictions of each model against the actual average leaf weights.
  - A scatter plot of weight as a function of Time.
  - Point-and-line plots of the predictions (pred) as a function of Time.
  - Notice that the linear model sometimes predicts negative weights! Does the gam model?
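The note above about as.numeric() can be illustrated without fitting a model. The following is a minimal sketch in base R, assuming a one-column matrix as a stand-in for a matrix-shaped prediction result; the names pred_matrix, pred_vector, and toy, and the values in them, are hypothetical and not part of the exercise.

```r
# Stand-in for a matrix-shaped prediction result (hypothetical values)
pred_matrix <- matrix(c(0.5, 1.2, 2.8), ncol = 1)

# as.numeric() drops the dim attribute and returns a plain numeric vector
pred_vector <- as.numeric(pred_matrix)

is.matrix(pred_matrix)     # TRUE
is.matrix(pred_vector)     # FALSE
is.null(dim(pred_vector))  # TRUE

# The plain vector can be stored as an ordinary data frame column
toy <- data.frame(Time = c(14, 21, 28))
toy$pred <- pred_vector
str(toy)
```

Storing a plain vector keeps the column a simple numeric column, which is what later reshaping and plotting steps generally expect.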
Hands-on interactive exercise
Try this exercise by completing the sample code.
# soybean_test is available
summary(soybean_test)
# Get predictions from linear model
soybean_test$pred.lin <- ___(___, newdata = ___)
# Get predictions from gam model
soybean_test$pred.gam <- ___(___(___, newdata = ___))
# Pivot the predictions into a "long" dataset
soybean_long <- soybean_test %>%
  pivot_longer(cols = c(___, ___), names_to = ___, values_to = ___)
# Calculate the rmse
soybean_long %>%
  mutate(residual = weight - pred) %>% # residuals
  group_by(modeltype) %>% # group by modeltype
  summarize(rmse = ___(___(___))) # calculate the RMSE
# Compare the predictions against actual weights on the test data
soybean_long %>%
  ggplot(aes(x = Time)) + # the column for the x axis
  geom_point(aes(y = weight)) + # the y-column for the scatterplot
  geom_point(aes(y = pred, color = modeltype)) + # the y-column for the point-and-line plot
  geom_line(aes(y = pred, color = modeltype, linetype = modeltype)) + # the y-column for the point-and-line plot
  scale_color_brewer(palette = "Dark2")
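
As a rough illustration of the per-model RMSE comparison, here is a sketch in base R on a small hand-made long data frame. It assumes the usual definition of RMSE, the square root of the mean squared residual. The data frame toy_long, its values, and rmse_by_model are hypothetical; tapply() is used here in place of the group_by() and summarize() pipeline so the sketch needs no extra packages.

```r
# Hypothetical long-format data: one row per observation and model type
toy_long <- data.frame(
  modeltype = c("pred.lin", "pred.lin", "pred.gam", "pred.gam"),
  weight    = c(1, 3, 1, 3),
  pred      = c(2, 5, 1, 4)
)

# Residuals: actual value minus predicted value
toy_long$residual <- toy_long$weight - toy_long$pred

# RMSE per model type: square root of the mean squared residual
rmse_by_model <- tapply(
  toy_long$residual,
  toy_long$modeltype,
  function(r) sqrt(mean(r^2))
)

rmse_by_model
```

The model type with the smaller value has the lower average prediction error on this toy data; the same comparison applies to the two prediction columns in soybean_long.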