Daten in Trainings- und Validierungssätze aufteilen
Du hast gelernt, dass die Verwendung nur der Trainingsdaten ohne Validierungsdatensatz zu einem Problem namens Überanpassung führt. Wenn es zu einer Überanpassung kommt, kann das Modell zwar die Daten für die Trainingsinputs super vorhersagen, aber es kann die Daten, die es noch nicht gesehen hat, nur schlecht verallgemeinern. Das heißt, das Modell ist nicht so nützlich, weil es nicht verallgemeinert werden kann. Um das zu vermeiden, kannst du einen Validierungsdatensatz verwenden.
In dieser Übung erstellst du einen Trainings- und Validierungssatz aus dem Datensatz, den du hast (also en_text
mit 1000 englischen Sätzen und fr_text
mit den 1000 französischen Sätzen). Du wirst 80 % des Datensatzes als Trainingsdaten und 20 % als Validierungsdaten verwenden.
Diese Übung ist Teil des Kurses
Maschinelle Übersetzung mit Keras
Anleitung zur Übung
- Definiere eine Folge von Indizes mit „
np.arange()
“, die bei 0 anfängt und die Größe „en_text
“ hat. - Definier „
valid_inds
“ als die letzten Indizes „valid_size
“ aus der Indizes-Sequenz. - Definiere „
tr_en
“ und „tf_fr
“, die die Sätze enthalten, die untertrain_inds
indices zu finden sind, in den Listen „en_text
“ und „fr_text
“. - Definiere „
v_en
” und „v_fr
”, die die Sätze enthalten, die untervalid_inds
indices in den Listen „en_text
” und „fr_text
” zu finden sind.
Interaktive Übung
Versuche dich an dieser Übung, indem du diesen Beispielcode vervollständigst.
train_size, valid_size = 800, 200
# Define a sequence of indices from 0 to len(en_text)
inds = ____.____(len(_____))
np.random.shuffle(inds)
train_inds = inds[:train_size]
# Define valid_inds: last valid_size indices
valid_inds = inds[____]
# Define tr_en (train EN sentences) and tr_fr (train FR sentences)
tr_en = [en_text[____] for ti in ____]
tr_fr = [____ for ti in ____]
# Define v_en (valid EN sentences) and v_fr (valid FR sentences)
v_en = [____ for vi in valid_inds]
v_fr = [____ for vi in ____]
print('Training (EN):\n', tr_en[:3], '\nTraining (FR):\n', tr_fr[:3])
print('\nValid (EN):\n', v_en[:3], '\nValid (FR):\n', v_fr[:3])