
Training the NMT model

1. Training the NMT model

You will now learn how to train the model.

2. Revisiting the model

So far you have an encoder-decoder based neural machine translation model. It has an Encoder GRU that consumes English words and outputs a context vector. Then you have a Decoder GRU which consumes the context vector and outputs a sequence of GRU outputs. Finally, you have a Decoder prediction layer which outputs the prediction probabilities.
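
The sketch below shows one way this architecture could look in Keras. The sequence lengths (en_len, fr_len), vocabulary sizes (en_vocab, fr_vocab) and hidden size are placeholder values, not from the lesson, and the use of RepeatVector to feed the context vector to the decoder is one possible design choice.

```python
# A minimal sketch of the encoder-decoder model described above (assumed shapes).
import tensorflow as tf
from tensorflow.keras.layers import Input, GRU, Dense, TimeDistributed, RepeatVector
from tensorflow.keras.models import Model

en_len, fr_len = 15, 20          # assumed sequence lengths
en_vocab, fr_vocab = 150, 200    # assumed vocabulary sizes
hsize = 48                       # hidden size of the GRU layers

# Encoder GRU: consumes one-hot encoded English words and outputs a context vector
en_inputs = Input(shape=(en_len, en_vocab))
en_out, en_state = GRU(hsize, return_state=True)(en_inputs)

# Decoder GRU: consumes the (repeated) context vector and outputs a sequence of GRU outputs
de_inputs = RepeatVector(fr_len)(en_state)
de_out = GRU(hsize, return_sequences=True)(de_inputs)

# Decoder prediction layer: outputs prediction probabilities over the French vocabulary
de_pred = TimeDistributed(Dense(fr_vocab, activation='softmax'))(de_out)

nmt = Model(inputs=en_inputs, outputs=de_pred)
```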

3. Optimizing the parameters

When you implement a Keras model, it has parameters. For example, all GRU layers and Dense layers have parameters. These parameters are known as weights and biases, which we saw when discussing the Dense layer. They are responsible for converting the inputs to a useful output; for example, converting the one-hot encoded English words to probability distributions over a sequence of French words. Typically, these parameters are initialized randomly. To make these parameters useful you need to train them, during which they are changed to produce meaningful outputs. To train the model successfully you need to define a loss function and an optimizer using the compile function. The loss is computed by first generating predictions from the inputs and then measuring the difference between the predictions and the actual outputs, or targets. The optimizer then changes the parameters in a way that minimizes the computed loss.
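
One possible compile call for the model sketched earlier is shown below. The specific choices of Adam as the optimizer and categorical crossentropy as the loss are illustrative; the loss fits because the predictions are probability distributions over words.

```python
# Illustrative compile call: defines the loss, the optimizer, and an accuracy metric.
nmt.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
```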

4. Training the model

To train the model, you iterate through the training dataset in batches. Processing a single batch is referred to as a single iteration, and a single pass through all the batches is called an epoch. You perform multiple epochs over the training dataset. At each iteration, you get a batch of inputs and a batch of targets using the sents2seqs function. Note how you specify language-specific attributes, like the tokenizer, using the first argument, which can be "source" or "target". You also reverse the encoder text. Then you call the train_on_batch function with the batch of inputs and targets to train the model. Finally, you can evaluate the model by calling the evaluate function, which gives you two metrics: the loss and the accuracy.
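
A sketch of this training loop is shown below. The exact signature of sents2seqs, the sentence lists en_text and fr_text, and the epoch and batch-size values are assumptions for illustration.

```python
# Assumed: sents2seqs(language, sentences, onehot=..., reverse=...) returns padded,
# one-hot encoded sequences; en_text and fr_text are lists of sentences; nmt is compiled.
n_epochs, bsize = 5, 250   # assumed number of epochs and batch size

for ei in range(n_epochs):
    for i in range(0, len(en_text), bsize):
        # Batch of reversed, one-hot encoded English inputs
        en_x = sents2seqs('source', en_text[i:i+bsize], onehot=True, reverse=True)
        # Batch of one-hot encoded French targets
        de_y = sents2seqs('target', fr_text[i:i+bsize], onehot=True)
        # Update the model parameters on this batch
        nmt.train_on_batch(en_x, de_y)
    # Get the loss and the accuracy on the last batch
    res = nmt.evaluate(en_x, de_y, batch_size=bsize, verbose=0)
    print("Epoch {} => Loss: {}, Accuracy: {}".format(ei + 1, res[0], res[1]))
```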

5. Training the model

res will be a tuple whose first element is the loss and whose second element is the accuracy. Accuracy is computed as the number of correctly classified samples divided by the total number of samples, multiplied by 100 to express it as a percentage. Remember that each epoch has multiple iterations.
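
The snippet below illustrates unpacking the result and the accuracy computation; the numbers are hypothetical.

```python
# res is a (loss, accuracy) pair, so it unpacks directly; values here are made up.
res = (0.45, 0.88)
loss, acc = res

# Accuracy as a percentage: correctly classified samples / total samples * 100
n_correct, n_total = 880, 1000
accuracy_pct = n_correct / n_total * 100   # => 88.0
```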

6. Avoiding overfitting

However, in practice you should use two datasets: a training set to train the model on, and a validation set to monitor the accuracy on. With this mechanism, you can stop the training when the validation accuracy stops increasing, which helps prevent overfitting.
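
A minimal sketch of that monitoring idea, using a made-up history of per-epoch validation accuracies, could look like this.

```python
# Stop once the validation accuracy no longer improves (hypothetical values).
val_accuracies = [0.61, 0.68, 0.72, 0.71, 0.70]

best_acc = 0.0
for epoch, v_acc in enumerate(val_accuracies):
    if v_acc <= best_acc:
        print("Validation accuracy stopped increasing; stopping at epoch", epoch)
        break
    best_acc = v_acc
```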

7. Splitting the dataset

Let's consider a dataset of 1000 sentences and split it into 800 training samples and 200 validation samples. You can use shuffling to make sure that you are not biasing the datasets. You can then take the first 800 shuffled indices as training indices and the last 200 as validation indices.
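
One way to do this with NumPy is sketched below; the dataset and split sizes are the ones from the example above.

```python
# Shuffle 1000 indices and split them into 800 training and 200 validation indices.
import numpy as np

dataset_size = 1000
train_size = 800

inds = np.arange(dataset_size)
np.random.shuffle(inds)          # shuffle so neither split is biased

train_inds = inds[:train_size]   # first 800 shuffled indices
valid_inds = inds[train_size:]   # last 200 shuffled indices
```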

8. Splitting the dataset

You then split the data into training and validation sets by taking the data corresponding to the training indices as the training data, that is tr_en and tr_fr, and the data corresponding to the validation indices as the validation data, v_en and v_fr.
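
Continuing the previous sketch, and assuming en_text and fr_text are the lists of English and French sentences, the split could look like this.

```python
# Split the sentence lists using the shuffled index arrays from the previous sketch.
tr_en = [en_text[ti] for ti in train_inds]   # training English sentences
tr_fr = [fr_text[ti] for ti in train_inds]   # training French sentences
v_en = [en_text[vi] for vi in valid_inds]    # validation English sentences
v_fr = [fr_text[vi] for vi in valid_inds]    # validation French sentences
```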

9. Training the model with validation

When training the model with a validation step, the training part stays the same. Then, in each epoch, you evaluate the model on the validation dataset, which is v_en_x and v_de_y. v_en_x and v_de_y need to be prepared using the same transformations used for en_x and de_y. You then use the evaluate function to get res, which contains the loss and the accuracy on the validation dataset. Since you are using a small validation set, let's set the batch_size to the validation set size.
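
The sketch below combines the earlier training loop with this validation step. As before, the sents2seqs signature and the epoch and batch-size values are assumptions.

```python
# Prepare the validation data with the same transformations used for the training batches.
v_en_x = sents2seqs('source', v_en, onehot=True, reverse=True)
v_de_y = sents2seqs('target', v_fr, onehot=True)

n_epochs, bsize = 5, 250   # assumed values

for ei in range(n_epochs):
    for i in range(0, len(tr_en), bsize):
        en_x = sents2seqs('source', tr_en[i:i+bsize], onehot=True, reverse=True)
        de_y = sents2seqs('target', tr_fr[i:i+bsize], onehot=True)
        nmt.train_on_batch(en_x, de_y)
    # Validation loss and accuracy at the end of each epoch,
    # with batch_size set to the validation set size
    v_loss, v_acc = nmt.evaluate(v_en_x, v_de_y, batch_size=len(v_en), verbose=0)
    print("Epoch {} => Val Loss: {}, Val Accuracy: {}".format(ei + 1, v_loss, v_acc))
```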

10. Let's practice!

You now know how to train your model and keep a lookout for overfitting. Let's practice!