LoslegenKostenlos loslegen

Das Modell trainieren

In dieser Übung trainierst du das zuvor implementierte Modell. Wusstest du, dass das auf einem Encoder-Decoder basierende maschinelle Übersetzungsmodell von Google 2 bis 4 Tage zum Trainieren gebraucht hat?

Für diese Übung wirst du einen kleinen Datensatz mit 1500 Sätzen ( en_text und fr_text) zum Trainieren des Modells verwenden. Dieser Betrag wird kaum für eine gute Leistung reichen, aber die Methode bleibt die gleiche. Es geht darum, länger mit mehr Daten zu trainieren. Du hast auch die zuvor implementierten Funktionen „ nmt “ und „ sents2seqs() “ bekommen. In dieser Übung kehrst du den Encoder-Text um, um eine bessere Leistung zu erzielen. Hier bezieht sich „ en_x “ auf den Encoder-Eingang, während „ de_x “ den Decoder-Eingang bedeutet.

Diese Übung ist Teil des Kurses

Maschinelle Übersetzung mit Keras

Kurs anzeigen

Anleitung zur Übung

  • Hol dir mit der Funktion „ sents2seqs() “ einen einzelnen Stapel von Encoder-Eingaben (englische Sätze von Index i bis i+bsize). Die Eingaben müssen umgekehrt und mit One-Hot-Kodierung versehen werden.
  • Hol dir mit der Funktion „ sents2seqs() “ eine einzelne Charge von Decoder-Ausgaben (französische Sätze von Index i bis i+bsize). Die Eingaben müssen mit One-Hot-Kodierung kodiert sein.
  • Trainiere das Modell mit einem einzigen Datenbatch, der „ en_x “ und „ de_y “ enthält.
  • Hol dir die Bewertungsmetriken für „ en_x ” und „ de_y ”, indem du das Modell mit einem „ batch_size ” von „ bsize ” bewertest.

Interaktive Übung

Versuche dich an dieser Übung, indem du diesen Beispielcode vervollständigst.

n_epochs, bsize = 3, 250

for ei in range(n_epochs):
  for i in range(0,data_size,bsize):
    # Get a single batch of encoder inputs
    en_x = ____('source', ____, onehot=____, reverse=____)
    # Get a single batch of decoder outputs
    de_y = sents2seqs('target', fr_text[____], onehot=____)
    
    # Train the model on a single batch of data
    nmt.____(____, ____)    
    # Obtain the eval metrics for the training data
    res = nmt.____(____, de_y, batch_size=____, verbose=0)
    print("{} => Train Loss:{}, Train Acc: {}".format(ei+1,res[0], res[1]*100.0))  
Code bearbeiten und ausführen