Aan de slagGa gratis aan de slag

Voorspellen met het sojaboonmodel op testdata

In deze oefening pas je de sojaboonmodellen uit de vorige oefening (model.lin en model.gam, al geladen) toe op nieuwe data: soybean_test.

Deze oefening maakt deel uit van de cursus

Supervised Learning in R: Regressie

Cursus bekijken

Oefeninstructies

  • Maak een kolom soybean_test$pred.lin met voorspellingen van het lineaire model model.lin.
  • Maak een kolom soybean_test$pred.gam met voorspellingen van het gam-model model.gam.
    • Voor GAM-modellen geeft predict() een matrix terug, dus gebruik as.numeric() om de matrix naar een vector om te zetten.
  • Vul de lege plekken in om met pivot_longer() de voorspellingskolommen om te zetten naar één waardekolom pred met sleutelkolom modeltype. Noem het lange data frame soybean_long.
  • Bereken en vergelijk de RMSE van beide modellen.
    • Welk model presteert beter?
  • Voer de code uit om de voorspellingen van elk model te vergelijken met de werkelijke gemiddelde bladgewichten.
    • Een spreidingsdiagram van weight als functie van Time.
    • Punt-en-lijngrafieken van de voorspellingen (pred) als functie van Time.
    • Merk op dat het lineaire model soms negatieve gewichten voorspelt! Doet het gam-model dat ook?

Praktische interactieve oefening

Probeer deze oefening eens door deze voorbeeldcode in te vullen.

# soybean_test is available
summary(soybean_test)

# Get predictions from linear model
soybean_test$pred.lin <- ___(___, newdata = ___)

# Get predictions from gam model
soybean_test$pred.gam <- ___(___(___, newdata = ___))

# Pivot the predictions into a "long" dataset
soybean_long <- soybean_test %>%
  pivot_longer(cols = c(___, ___), names_to = ___, values_to = ___)

# Calculate the rmse
soybean_long %>%
  mutate(residual = weight - pred) %>%     # residuals
  group_by(modeltype) %>%                  # group by modeltype
  summarize(rmse = ___(___(___))) # calculate the RMSE

# Compare the predictions against actual weights on the test data
soybean_long %>%
  ggplot(aes(x = Time)) +                          # the column for the x axis
  geom_point(aes(y = weight)) +                    # the y-column for the scatterplot
  geom_point(aes(y = pred, color = modeltype)) +   # the y-column for the point-and-line plot
  geom_line(aes(y = pred, color = modeltype, linetype = modeltype)) + # the y-column for the point-and-line plot
  scale_color_brewer(palette = "Dark2")
  
Code bewerken en uitvoeren