Recherche aléatoire
# Call GridSearchCV
grid_search = GridSearchCV(clf, param_grid)
# Fit the model
grid_search.fit(X, y)
Dans l’extrait de code ci‑dessus, issu de l’exercice précédent, vous avez peut‑être remarqué que la première ligne s’exécute très rapidement, alors que l’appel à .fit() prend plusieurs secondes.
C’est parce que .fit() effectue réellement la recherche sur grille, et dans notre cas, la grille comportait de nombreuses combinaisons. Plus la grille d’hyperparamètres est grande, plus la recherche sur grille devient lente. Pour contourner ce problème, plutôt que d’essayer toutes les combinaisons possibles, on peut parcourir la grille de façon aléatoire et tester différentes combinaisons. Il y a un faible risque de passer à côté de la meilleure combinaison, mais on gagne beaucoup de temps, ou on peut ajuster davantage d’hyperparamètres dans le même laps de temps.
Avec scikit-learn, vous pouvez faire cela avec RandomizedSearchCV. Son API est identique à celle de GridSearchCV, sauf que vous devez fournir une distribution de paramètres dans laquelle échantillonner au lieu de valeurs d’hyperparamètres précises. Essayons maintenant ! La distribution de paramètres a été préparée pour vous, ainsi qu’un classifieur random forest appelé clf.
Cet exercice fait partie du cours
Marketing Analytics : prédire l’attrition client en Python
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Import RandomizedSearchCV