Defining the decoder of the inference model
The inference model is the model that will be used out in the wild to perform translations when required by the user. In this exercise, you will need to implement the decoder of the inference model.
The inference model decoder is different to the decoder of the training model. We can't feed the decoder with French words because that is what we want to predict. Luckily, there is a solution. We can use the predicted French word from the previous time step to feed the inference model decoder. Therefore, when you want to generate a translation, the decoder needs to generate one word at a time, while consuming the previous output as an input.

For this exercise, the variables hsize (hidden size of the GRU layer), fr_len and fr_vocab have been imported. Remember that the prefix de is used to refer to the decoder.
This exercise is part of the course Machine Translation with Keras.
Exercise instructions
- Define an Input layer which accepts a batch of onehot encoded French word sequences (sequence length 1).
- Define another Input layer which accepts a batch of states of size hsize, which you will use to feed the previous state to the decoder.
- Get the output and state of the decoder GRU.
- Define a model that accepts the French words Input and the previous state Input and outputs the final prediction and the new GRU state.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
import tensorflow.keras.layers as layers
from tensorflow.keras.models import Model
# Define an input layer that accepts a single onehot encoded word
de_inputs = layers.____(shape=(____, ____))
# Define an input to accept the t-1 state
de_state_in = layers.____(shape=(____,))
de_gru = layers.GRU(hsize, return_state=True)
# Get the output and state from the GRU layer
de_out, de_state_out = ____(de_inputs, initial_state=____)
de_dense = layers.Dense(fr_vocab, activation='softmax')
de_pred = de_dense(de_out)
# Define a model
decoder = Model(inputs=[____, ____], outputs=[____, ____])
# summary() prints the table itself and returns None, so no print() is needed
decoder.summary()
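One possible completion of the template is sketched below, followed by a single decoding step on dummy data. The values of hsize and fr_vocab are made up here so that the block runs on its own (in the exercise they are already imported), and the batch size of 2 and the dummy inputs are assumptions for the demonstration.

```python
import numpy as np
from tensorflow.keras import layers
from tensorflow.keras.models import Model

# Assumed toy values; the exercise environment provides the real ones
hsize, fr_vocab = 48, 200

# One onehot encoded French word per sample: (sequence length 1, vocabulary size)
de_inputs = layers.Input(shape=(1, fr_vocab))
# Previous decoder state: one vector of size hsize per sample
de_state_in = layers.Input(shape=(hsize,))

de_gru = layers.GRU(hsize, return_state=True)
# Run the GRU for a single step, starting from the supplied state
de_out, de_state_out = de_gru(de_inputs, initial_state=de_state_in)

de_dense = layers.Dense(fr_vocab, activation='softmax')
de_pred = de_dense(de_out)

# Two inputs (word, previous state) and two outputs (prediction, new state)
decoder = Model(inputs=[de_inputs, de_state_in], outputs=[de_pred, de_state_out])

# Run one step on a dummy batch of 2 samples
batch = 2
word = np.zeros((batch, 1, fr_vocab), dtype="float32")
word[:, 0, 0] = 1.0  # arbitrary word index 0 as the input word
state = np.zeros((batch, hsize), dtype="float32")
pred, new_state = decoder.predict([word, state], verbose=0)
print(pred.shape, new_state.shape)
```

With these assumed sizes, the prediction has one probability per French vocabulary entry for each sample, and the returned state has the same shape as the state that was passed in, so it can be fed straight back into the next step.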