Memprediksi dengan model soybean pada data uji
Pada latihan ini, Anda akan menerapkan model soybean dari latihan sebelumnya (model.lin dan model.gam, sudah dimuat) ke data baru: soybean_test.
Latihan ini adalah bagian dari kursus
Supervised Learning di R: Regresi
Petunjuk latihan
- Buat kolom
soybean_test$pred.linberisi prediksi dari model linearmodel.lin. - Buat kolom
soybean_test$pred.gamberisi prediksi dari model GAMmodel.gam.- Untuk model GAM, metode
predict()mengembalikan matriks, jadi gunakanas.numeric()untuk mengonversi matriks menjadi vektor.
- Untuk model GAM, metode
- Lengkapi bagian yang kosong untuk melakukan
pivot_longer()pada kolom-kolom prediksi menjadi satu kolom nilaipreddengan kolom kuncimodeltype. Beri nama kerangka data panjang tersebutsoybean_long. - Hitung dan bandingkan RMSE dari kedua model.
- Model mana yang lebih baik?
- Jalankan kode untuk membandingkan prediksi tiap model dengan bobot daun rata-rata aktual.
- Sebuah scatter plot
weightsebagai fungsi dariTime. - Plot titik-dan-garis dari prediksi (
pred) sebagai fungsi dariTime. - Perhatikan bahwa model linear kadang memprediksi bobot negatif! Apakah model GAM juga demikian?
- Sebuah scatter plot
Latihan interaktif praktis
Cobalah latihan ini dengan menyelesaikan kode contoh berikut.
# soybean_test is available
summary(soybean_test)
# Get predictions from linear model
soybean_test$pred.lin <- ___(___, newdata = ___)
# Get predictions from gam model
soybean_test$pred.gam <- ___(___(___, newdata = ___))
# Pivot the predictions into a "long" dataset
soybean_long <- soybean_test %>%
pivot_longer(cols = c(___, ___), names_to = ___, values_to = ___)
# Calculate the rmse
soybean_long %>%
mutate(residual = weight - pred) %>% # residuals
group_by(modeltype) %>% # group by modeltype
summarize(rmse = ___(___(___))) # calculate the RMSE
# Compare the predictions against actual weights on the test data
soybean_long %>%
ggplot(aes(x = Time)) + # the column for the x axis
geom_point(aes(y = weight)) + # the y-column for the scatterplot
geom_point(aes(y = pred, color = modeltype)) + # the y-column for the point-and-line plot
geom_line(aes(y = pred, color = modeltype, linetype = modeltype)) + # the y-column for the point-and-line plot
scale_color_brewer(palette = "Dark2")