Diviser les données en ensembles d'entraînement et de validation
Vous avez appris que l'utilisation exclusive des données d'entraînement sans ensemble de données de validation entraîne un problème appelé « surapprentissage ». En cas de surajustement, le modèle sera très performant pour prédire les données utilisées pour l'entraînement, mais très peu pour généraliser à des données inconnues. Cela signifie que le modèle ne sera pas très utile, car il ne peut pas être généralisé. Pour éviter cela, vous pouvez utiliser un ensemble de données de validation.
Dans cet exercice, vous allez créer un ensemble d'entraînement et de validation à partir de l'ensemble de données dont vous disposez (à savoir en_text
contenant 1 000 phrases en anglais et fr_text
contenant les 1 000 phrases en français). Vous utiliserez 80 % de l'ensemble de données pour l'entraînement et 20 % pour la validation.
Cet exercice fait partie du cours
Traduction automatique avec Keras
Instructions
- Définissez une séquence d'index à l'aide d'
np.arange()
, qui commence par 0 et a une taille deen_text
. - Définissez l'
valid_inds
comme étant les derniers indices d'valid_size
s de la séquence d'indices. - Définissez
tr_en
ettf_fr
, qui contiennent les phrases trouvées à l'adressetrain_inds
indices, dans les listesen_text
etfr_text
. - Définissez
v_en
etv_fr
qui contiennent les phrases trouvées à l'adressevalid_inds
indices, dans les listesen_text
etfr_text
.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
train_size, valid_size = 800, 200
# Define a sequence of indices from 0 to len(en_text)
inds = ____.____(len(_____))
np.random.shuffle(inds)
train_inds = inds[:train_size]
# Define valid_inds: last valid_size indices
valid_inds = inds[____]
# Define tr_en (train EN sentences) and tr_fr (train FR sentences)
tr_en = [en_text[____] for ti in ____]
tr_fr = [____ for ti in ____]
# Define v_en (valid EN sentences) and v_fr (valid FR sentences)
v_en = [____ for vi in valid_inds]
v_fr = [____ for vi in ____]
print('Training (EN):\n', tr_en[:3], '\nTraining (FR):\n', tr_fr[:3])
print('\nValid (EN):\n', v_en[:3], '\nValid (FR):\n', v_fr[:3])