Validação cruzada K-fold
Você vai trabalhar com um problema de classificação binária usando uma subamostra de uma competição playground do Kaggle. O objetivo da competição é prever se o famoso jogador de basquete Kobe Bryant acertou a cesta ou errou um arremesso específico.
Os dados de treino estão disponíveis no seu workspace como o DataFrame bryant_shots. Ele contém dados de 10.000 arremessos com suas propriedades e a variável target "shot\_made\_flag" — indicando se o arremesso foi convertido ou não.
Uma das features nos dados é "game_id" — o jogo específico em que o arremesso ocorreu. Existem 541 jogos distintos. Ou seja, você está lidando com uma variável categórica de alta cardinalidade. Vamos codificá-la usando a média do target!
Suponha que você esteja usando validação cruzada com 5 folds e quer avaliar uma feature codificada por média do target na validação local.
Este exercício faz parte do curso
Vencendo uma competição do Kaggle em Python
Instruções do exercício
- Para isso, você precisa repetir o procedimento de codificação para a variável categórica
"game_id"separadamente dentro de cada divisão dos folds. Seu objetivo é especificar todos os parâmetros que faltam na chamada da funçãomean_target_encoding()dentro de cada divisão dos folds. - Lembre-se de que os parâmetros
trainetestesperam os DataFrames de treino e de teste. - Já os parâmetros
targetecategoricalesperam os nomes da variável target e da variável categórica a ser codificada.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
# Create 5-fold cross-validation
kf = KFold(n_splits=5, random_state=123, shuffle=True)
# For each folds split
for train_index, test_index in kf.split(bryant_shots):
cv_train, cv_test = bryant_shots.iloc[train_index], bryant_shots.iloc[test_index]
# Create mean target encoded feature
cv_train['game_id_enc'], cv_test['game_id_enc'] = mean_target_encoding(train=cv_train,
test=____,
target='shot_made_flag',
categorical='____',
alpha=5)
# Look at the encoding
print(cv_train[['game_id', 'shot_made_flag', 'game_id_enc']].sample(n=1))