1. Learn
  2. /
  3. Courses
  4. /
  5. sparklyr を使った Spark 入門(R)

Connected

Exercise

勾配ブースティング木:予測

モデルを実行したら、次のステップは予測を行うことです。ベース R では predict() 関数を使って予測しますが、sparklyr では ml_predict() 関数を使います。ml_predict() は、モデルとテストデータの 2 つの引数を受け取ります。

ml_predict(a_model, testing_data)

よくある使い方として、予測値と実際の値を比較し、R でグラフを描画する方法があります。そのためのデータ準備のコードパターンは以下のとおりです。なお、現時点では予測列の追加はローカルで行う必要があるため、先に結果を収集しておく必要があります。

predicted_vs_actual <- testing_data %>%
  select(actual) %>%
  collect() %>%
  mutate(predicted)

Instructions

100 XP

Spark への接続は spark_conn として事前に作成されています。Spark に保存されたトレーニングデータとテストデータに紐付けられたティブルは、それぞれ track_data_to_model_tbl と track_data_to_predict_tbl として定義済みです。勾配ブースティング木モデルは gradient_boosted_trees_model として定義済みです。

  • テストデータに対するモデルの予測結果を格納する変数 predicted を定義しましょう。
    • モデルとテストデータを引数として ml_predict() を呼び出します。この関数はテストデータセットに対して予測を生成し、prediction という名前の新しい列として追加します。
    • pull() を使ってこの列を取り出し、predicted に代入します。
  • 予測値と実際の値を比較するためのデータを準備する変数 responses を定義しましょう。
    • レスポンス列 year を選択します。
    • 結果を収集します。
    • mutate() を使って、predicted で作成した予測値を追加します。