MulaiMulai sekarang secara gratis

Memprediksi dengan model soybean pada data uji

Pada latihan ini, Anda akan menerapkan model soybean dari latihan sebelumnya (model.lin dan model.gam, sudah dimuat) ke data baru: soybean_test.

Latihan ini adalah bagian dari kursus

Supervised Learning di R: Regresi

Lihat Kursus

Petunjuk latihan

  • Buat kolom soybean_test$pred.lin berisi prediksi dari model linear model.lin.
  • Buat kolom soybean_test$pred.gam berisi prediksi dari model GAM model.gam.
    • Untuk model GAM, metode predict() mengembalikan matriks, jadi gunakan as.numeric() untuk mengonversi matriks menjadi vektor.
  • Lengkapi bagian yang kosong untuk melakukan pivot_longer() pada kolom-kolom prediksi menjadi satu kolom nilai pred dengan kolom kunci modeltype. Beri nama kerangka data panjang tersebut soybean_long.
  • Hitung dan bandingkan RMSE dari kedua model.
    • Model mana yang lebih baik?
  • Jalankan kode untuk membandingkan prediksi tiap model dengan bobot daun rata-rata aktual.
    • Sebuah scatter plot weight sebagai fungsi dari Time.
    • Plot titik-dan-garis dari prediksi (pred) sebagai fungsi dari Time.
    • Perhatikan bahwa model linear kadang memprediksi bobot negatif! Apakah model GAM juga demikian?

Latihan interaktif praktis

Cobalah latihan ini dengan menyelesaikan kode contoh berikut.

# soybean_test is available
summary(soybean_test)

# Get predictions from linear model
soybean_test$pred.lin <- ___(___, newdata = ___)

# Get predictions from gam model
soybean_test$pred.gam <- ___(___(___, newdata = ___))

# Pivot the predictions into a "long" dataset
soybean_long <- soybean_test %>%
  pivot_longer(cols = c(___, ___), names_to = ___, values_to = ___)

# Calculate the rmse
soybean_long %>%
  mutate(residual = weight - pred) %>%     # residuals
  group_by(modeltype) %>%                  # group by modeltype
  summarize(rmse = ___(___(___))) # calculate the RMSE

# Compare the predictions against actual weights on the test data
soybean_long %>%
  ggplot(aes(x = Time)) +                          # the column for the x axis
  geom_point(aes(y = weight)) +                    # the y-column for the scatterplot
  geom_point(aes(y = pred, color = modeltype)) +   # the y-column for the point-and-line plot
  geom_line(aes(y = pred, color = modeltype, linetype = modeltype)) + # the y-column for the point-and-line plot
  scale_color_brewer(palette = "Dark2")
  
Edit dan Jalankan Kode