Defining the decoder of the inference model

The inference model is the model that will be used out in the wild to perform translations when required by the user. In this exercise, you will need to implement the decoder of the inference model.

The inference model decoder differs from the decoder of the training model. At inference time we can't feed the decoder French words, because the French words are exactly what we want to predict. Luckily, there is a solution: feed the inference model decoder the French word it predicted at the previous time step. To generate a translation, the decoder therefore produces one word at a time, consuming its previous output (and its previous state) as input.
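The word-by-word loop described above can be sketched in plain NumPy. This is only an illustration: decoder_step is a hypothetical stand-in with random weights, not the trained Keras decoder, and the toy sizes, the start word ID 0, the all-zero initial state and the use of fr_len as the maximum output length are all assumptions.

```python
import numpy as np

# Toy sizes, assumed for illustration only
fr_vocab, hsize, fr_len = 5, 4, 6

# Random weights for a hypothetical stand-in decoder (not a trained model)
rng = np.random.default_rng(0)
W_in = rng.normal(size=(fr_vocab, hsize))
W_state = rng.normal(size=(hsize, hsize))
W_out = rng.normal(size=(hsize, fr_vocab))

def decoder_step(word_onehot, state):
    """Hypothetical single decoding step: returns (word probabilities, new state)."""
    new_state = np.tanh(word_onehot @ W_in + state @ W_state)
    logits = new_state @ W_out
    probs = np.exp(logits - logits.max())  # numerically stable softmax
    probs /= probs.sum()
    return probs, new_state

def greedy_decode(initial_state, start_id, max_len):
    """Generate word IDs one at a time, feeding each prediction back in."""
    state = initial_state
    word_id = start_id
    output_ids = []
    for _ in range(max_len):
        # One-hot encode the previously predicted word
        onehot = np.zeros(fr_vocab)
        onehot[word_id] = 1.0
        # Consume the previous word and state, produce the next word and state
        probs, state = decoder_step(onehot, state)
        word_id = int(np.argmax(probs))
        output_ids.append(word_id)
    return output_ids

# In a real model the initial state would typically come from the encoder;
# zeros are used here as a placeholder.
translation_ids = greedy_decode(np.zeros(hsize), start_id=0, max_len=fr_len)
print(translation_ids)
```

Picking the most probable word at each step (greedy decoding) is the simplest choice; the important part is that each step receives the previous prediction and the previous state.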

For this exercise, the variables hsize (the hidden size of the GRU layer), fr_len and fr_vocab are already defined. Remember that the prefix de refers to the decoder.

This exercise is part of the course

Machine Translation with Keras

Exercise instructions

  • Define an Input layer which accepts a batch of one-hot encoded French word sequences (sequence length 1).
  • Define another Input layer which accepts a batch of states of size hsize, which you will use to feed the previous state to the decoder.
  • Get the output and state of the decoder GRU.
  • Define a model that accepts the French words Input and the previous state Input and outputs the final prediction and the new GRU state.

Hands-on interactive exercise

Have a go at this exercise by completing this sample code.

import tensorflow.keras.layers as layers
from tensorflow.keras.models import Model
# Define an input layer that accepts a single one-hot encoded word
de_inputs = layers.____(shape=(____, ____))
# Define an input to accept the t-1 state
de_state_in = layers.____(shape=(____,))
de_gru = layers.GRU(hsize, return_state=True)
# Get the output and state from the GRU layer
de_out, de_state_out = ____(de_inputs, initial_state=____)
de_dense = layers.Dense(fr_vocab, activation='softmax')
de_pred = de_dense(de_out)

# Define a model
decoder = Model(inputs=[____, ____], outputs=[____, ____])
# summary() prints the table itself and returns None, so no print() is needed
decoder.summary()
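One possible way to complete the template, followed by a single decoding step, is sketched below. The values of hsize and fr_vocab are assumed for illustration (the exercise environment supplies its own), and the weights here are randomly initialised, so the predictions are meaningless until trained weights are loaded into the layers.

```python
import numpy as np
from tensorflow.keras import layers
from tensorflow.keras.models import Model

# Assumed values; the exercise environment defines its own
hsize, fr_vocab = 48, 250

# Input for a batch of single one-hot encoded words: (batch, 1, fr_vocab)
de_inputs = layers.Input(shape=(1, fr_vocab))
# Input for the previous (t-1) GRU state: (batch, hsize)
de_state_in = layers.Input(shape=(hsize,))
de_gru = layers.GRU(hsize, return_state=True)
# Run one GRU step starting from the supplied state
de_out, de_state_out = de_gru(de_inputs, initial_state=de_state_in)
de_dense = layers.Dense(fr_vocab, activation='softmax')
de_pred = de_dense(de_out)

# The model maps (previous word, previous state) to (prediction, new state)
decoder = Model(inputs=[de_inputs, de_state_in], outputs=[de_pred, de_state_out])

# One decoding step on a batch of two placeholder words (word ID 0) and zero states
batch = 2
word = np.zeros((batch, 1, fr_vocab), dtype="float32")
word[:, 0, 0] = 1.0
state = np.zeros((batch, hsize), dtype="float32")
pred, new_state = decoder.predict([word, state], verbose=0)
print(pred.shape, new_state.shape)
```

The prediction has one probability per French vocabulary entry for each item in the batch, and the returned state has the same shape as the state input, so it can be passed straight back in at the next time step.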