1. Nauka
  2. /
  3. Kursy
  4. /
  5. Wprowadzenie do Spark z pakietem sparklyr w R

Connected

ćwiczenie

Gradient boosted trees: predykcja

Po uruchomieniu modelu kolejnym krokiem jest wykonanie predykcji. W odróżnieniu od bazowego R, które do predykcji używa funkcji predict(), pakiet sparklyr korzysta z funkcji ml_predict(). Funkcja ml_predict() przyjmuje dwa argumenty: model oraz dane testowe.

ml_predict(a_model, testing_data)

Częstym podejściem jest porównanie przewidywanych wartości z rzeczywistymi, które można następnie zwizualizować w R. Schemat kodu do przygotowania takich danych wygląda następująco. Zwróć uwagę, że dodanie kolumny z predykcją musi być wykonane lokalnie, dlatego najpierw trzeba zebrać wyniki.

predicted_vs_actual <- testing_data %>%
  select(actual) %>%
  collect() %>%
  mutate(predicted)

Instrukcje

100 XP

Połączenie ze Sparkiem zostało już utworzone i jest dostępne jako spark_conn. Tibble powiązane ze zbiorami treningowym i testowym przechowywanymi w Sparku są predefiniowane odpowiednio jako track_data_to_model_tbl i track_data_to_predict_tbl. Model gradient boosted trees jest predefiniowany jako gradient_boosted_trees_model.

  • Zdefiniuj zmienną predicted, która będzie zawierać predykcje modelu dla danych testowych.
    • Wywołaj funkcję ml_predict(), przekazując jako argumenty model oraz dane testowe. Funkcja ta wygeneruje predykcje dla zbioru testowego i doda je jako nową kolumnę o nazwie prediction.
    • Używając funkcji pull(), wyodrębnij tę kolumnę i przypisz ją do zmiennej predicted.
  • Zdefiniuj zmienną responses, aby przygotować dane do porównania wartości przewidywanych z rzeczywistymi:
    • Wybierz kolumnę z odpowiedzią year.
    • Zbierz wyniki.
    • Użyj funkcji mutate(), aby dodać predykcje zapisane w zmiennej predicted.