Diviser les données en ensembles d'entraînement et de validation
Vous avez appris que l'utilisation exclusive des données d'entraînement sans ensemble de données de validation entraîne un problème appelé « surapprentissage ». En cas de surajustement, le modèle sera très performant pour prédire les données utilisées pour l'entraînement, mais très peu pour généraliser à des données inconnues. Cela signifie que le modèle ne sera pas très utile, car il ne peut pas être généralisé. Pour éviter cela, vous pouvez utiliser un ensemble de données de validation.
Dans cet exercice, vous allez créer un ensemble d'entraînement et de validation à partir de l'ensemble de données dont vous disposez (à savoir en_text contenant 1 000 phrases en anglais et fr_text contenant les 1 000 phrases en français). Vous utiliserez 80 % de l'ensemble de données pour l'entraînement et 20 % pour la validation.
Cet exercice fait partie du cours
Traduction automatique avec Keras
Instructions
- Définissez une séquence d'index à l'aide d'
np.arange(), qui commence par 0 et a une taille deen_text. - Définissez l'
valid_indscomme étant les derniers indices d'valid_sizes de la séquence d'indices. - Définissez
tr_enettf_fr, qui contiennent les phrases trouvées à l'adressetrain_indsindices, dans les listesen_textetfr_text. - Définissez
v_enetv_frqui contiennent les phrases trouvées à l'adressevalid_indsindices, dans les listesen_textetfr_text.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
train_size, valid_size = 800, 200
# Define a sequence of indices from 0 to len(en_text)
inds = ____.____(len(_____))
np.random.shuffle(inds)
train_inds = inds[:train_size]
# Define valid_inds: last valid_size indices
valid_inds = inds[____]
# Define tr_en (train EN sentences) and tr_fr (train FR sentences)
tr_en = [en_text[____] for ti in ____]
tr_fr = [____ for ti in ____]
# Define v_en (valid EN sentences) and v_fr (valid FR sentences)
v_en = [____ for vi in valid_inds]
v_fr = [____ for vi in ____]
print('Training (EN):\n', tr_en[:3], '\nTraining (FR):\n', tr_fr[:3])
print('\nValid (EN):\n', v_en[:3], '\nValid (FR):\n', v_fr[:3])