ALS model buildout on MovieLens Data
1. ALS model buildout on MovieLens Data
If you remember from the last chapter, you built out a model on the ratings dataset. The code looked like this:
2. Fitting a basic model
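A minimal sketch of that buildout, assuming a ratings DataFrame named ratings with userId, movieId, and rating columns; the hyperparameter values here are illustrative, not necessarily the exact ones from the last chapter.

```python
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator

# Split the MovieLens ratings data into training and test sets
(training, test) = ratings.randomSplit([0.8, 0.2], seed=1234)

# Build an ALS model with hand-picked hyperparameter values
als = ALS(userCol="userId", itemCol="movieId", ratingCol="rating",
          rank=10, maxIter=15, regParam=0.1,
          coldStartStrategy="drop", nonnegative=True)

# Fit the model and generate predictions on the test set
model = als.fit(training)
predictions = model.transform(test)

# Evaluate performance with RMSE
evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating",
                                predictionCol="prediction")
print(evaluator.evaluate(predictions))
```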
Now, the RMSE that you got was lower than the 1.45 shown here. But what if you went through this whole process and got an error metric you weren't satisfied with, like this RMSE of 1.45? You might want to try other combinations of hyperparameter values to reduce it. Spark makes it easy to do this using two additional tools called
3. Intro to ParamGridBuilder and CrossValidator
the ParamGridBuilder and the CrossValidator. These tools allow you to try many different hyperparameter values and have Spark identify the best combination. Let's talk about how to use them.
4. ParamGridBuilder
The ParamGridBuilder tells Spark all the hyperparameter values you want it to try. To do this, we first import the ParamGridBuilder package, instantiate it, and give it a name; we'll call it param_grid. We then add each hyperparameter by calling the .addGrid() method with our als algorithm and the hyperparameter name, as you see here.
5. Adding Hyperparameters to the ParamGridBuilder
Notice the empty lists to the right of the hyperparameter names. This is where we input the values we want Spark to try for each hyperparameter, like this:
6. Adding Hyperparameter Values to the ParamGridBuilder
Once we've added all of this, we call the .build() method to complete the build of our param_grid.
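As a sketch, the finished grid build might look like this, reusing the als estimator from before; the value lists are placeholders rather than the course's exact grid.

```python
from pyspark.ml.tuning import ParamGridBuilder

# One .addGrid() call per hyperparameter, each with a list of
# candidate values for Spark to try (values here are illustrative)
param_grid = ParamGridBuilder() \
    .addGrid(als.rank, [5, 10, 20]) \
    .addGrid(als.maxIter, [5, 10, 20]) \
    .addGrid(als.regParam, [0.01, 0.05, 0.1]) \
    .build()
```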
Now let's look at the CrossValidator.
7. CrossValidator
The CrossValidator essentially fits a model to several different portions of our training dataset, called folds, and then generates predictions for each respective holdout portion of the dataset to see how it performs.
8. CrossValidator instantiation and estimator
To properly use the CrossValidator, we first import the CrossValidator package, instantiate a CrossValidator, and give it a name; we'll call it cv here.
9. CrossValidator ParamMaps
We then tell it to use our als model as an estimator by setting the estimator argument equal to the name of our model, which is als. We set the estimatorParamMaps argument to the param_grid we built, so that Spark knows what values to try as it works to identify the best combination of hyperparameters. Then we provide the name of our evaluator, so it knows how to measure each model's performance, by setting the evaluator argument to the name of our evaluator, which is evaluator.
10. CrossValidator
We finish by setting the numFolds argument to the number of times we want Spark to test each model on the training data, in this case, 5 times.
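Assembled from the pieces above, the instantiation looks roughly like this sketch:

```python
from pyspark.ml.tuning import CrossValidator

# Wire the estimator, parameter grid, and evaluator together,
# with 5-fold cross-validation on the training data
cv = CrossValidator(estimator=als,
                    estimatorParamMaps=param_grid,
                    evaluator=evaluator,
                    numFolds=5)
```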
Let's go over how to integrate these into a full code buildout.
11. Random split
We'll first split our data into training and test sets using the randomSplit() method, and we'll build a generic ALS model without any hyperparameters, only the model parameters, as you see here. The cross-validator will take care of the hyperparameters.
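As a sketch, with column names assumed from the MovieLens schema:

```python
from pyspark.ml.recommendation import ALS

# Random 80/20 train/test split
(train, test) = ratings.randomSplit([0.8, 0.2], seed=1234)

# Generic ALS estimator: model parameters only, no hyperparameters;
# the ParamGridBuilder and CrossValidator will supply those
als = ALS(userCol="userId", itemCol="movieId", ratingCol="rating",
          coldStartStrategy="drop", nonnegative=True)
```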
12. ParamGridBuilder
We'll build our ParamGridBuilder so Spark knows what hyperparameter values to test.
13. Evaluator
We'll create an evaluator so Spark knows how to evaluate each model.
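For instance, an RMSE evaluator comparing actual ratings to predictions:

```python
from pyspark.ml.evaluation import RegressionEvaluator

# Measure each model's performance by root mean squared error
evaluator = RegressionEvaluator(metricName="rmse",
                                labelCol="rating",
                                predictionCol="prediction")
```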
14. CrossValidator
Then the CrossValidator will tell Spark the algorithm, the hyperparameters and values, and the evaluator to use to find the best model, along with the number of training-set folds we want each model to be tested on.
15. Best model
We then fit our CrossValidator on the training data by calling the cv.fit() method, having Spark try all the combinations of hyperparameters we specified. Once it's finished running, we extract the best-performing model through the fitted cross-validator's bestModel attribute. We'll call this our best_model.
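In code, roughly:

```python
# Fit the CrossValidator: Spark trains one model per hyperparameter
# combination per fold, so this step can take a while
model = cv.fit(train)

# Pull out the single best-performing model
best_model = model.bestModel
```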
16. Predictions and performance evaluation
With it, we can generate predictions on the test set, print the error metric, and print the respective hyperparameter values using the code you see here.
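A sketch of that final step; ALSModel exposes rank directly, while the _java_obj.parent() calls are one common workaround for reading the winning maxIter and regParam back off the fitted model.

```python
# Generate test-set predictions and print the RMSE
test_predictions = best_model.transform(test)
print("RMSE:", evaluator.evaluate(test_predictions))

# Inspect the hyperparameter values the cross-validation chose
print("Rank:", best_model.rank)
print("MaxIter:", best_model._java_obj.parent().getMaxIter())
print("RegParam:", best_model._java_obj.parent().getRegParam())
```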
And now we have our cross-validated model.
17. Let's practice!
Let's build a real model on a real dataset.