Búsqueda aleatoria
# Call GridSearchCV
grid_search = GridSearchCV(clf, param_grid)
# Fit the model
grid_search.fit(X, y)
En el fragmento de código anterior del ejercicio previo, habrás notado que la primera línea no tardó mucho en ejecutarse, mientras que la llamada a .fit() tardó varios segundos.
Esto se debe a que .fit() es lo que realmente ejecuta la búsqueda en rejilla y, en nuestro caso, la rejilla tenía muchas combinaciones distintas. A medida que la rejilla de hiperparámetros crece, la búsqueda en rejilla se vuelve más lenta. Para resolver este problema, en lugar de probar todas y cada una de las combinaciones posibles, podemos saltar por la rejilla de forma aleatoria y probar combinaciones distintas. Existe una pequeña posibilidad de que no encontremos la combinación óptima, pero ahorraríamos mucho tiempo o podríamos ajustar más hiperparámetros en el mismo tiempo.
En scikit-learn, puedes hacerlo usando RandomizedSearchCV. Tiene la misma API que GridSearchCV, salvo que debes especificar una distribución de parámetros de la que pueda muestrear en lugar de valores concretos. ¡Vamos a probarlo! La distribución de parámetros ya está preparada, junto con un clasificador de random forest llamado clf.
Este ejercicio forma parte del curso
Marketing Analytics: Predicción de churn de clientes en Python
Ejercicio interactivo práctico
Prueba este ejercicio y completa el código de muestra.
# Import RandomizedSearchCV