Defining the embedding model
You will be defining a Keras model that:
- Uses Embedding layers
- Will be trained with Teacher Forcing
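This model will have two embedding layers: an encoder embedding layer and a decoder embedding layer. Furthermore, because the model is trained using Teacher Forcing, the decoder Input layer uses a sequence length of fr_len-1. The decoder is fed the French sentence without its last word and learns to predict the same sentence without its first word.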
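The fr_len-1 length comes from the usual Teacher Forcing setup: the decoder input is the French sequence minus its last word, and the target is the same sequence shifted one step left (minus its first word). The pure-Python sketch below illustrates this on a made-up sequence of word IDs. The helper name `teacher_forcing_split` and the IDs are hypothetical, chosen only for illustration.

```python
def teacher_forcing_split(fr_ids):
    """Hypothetical helper: split one French sequence of word IDs into
    decoder input and decoder target for Teacher Forcing.

    At step t the decoder is fed the true word fr_ids[t] and is trained
    to predict the next true word fr_ids[t + 1].
    """
    de_input = fr_ids[:-1]   # every word except the last
    de_target = fr_ids[1:]   # every word except the first
    return de_input, de_target


# A made-up French sentence of fr_len = 6 word IDs
fr_ids = [1, 17, 42, 8, 23, 2]
de_input, de_target = teacher_forcing_split(fr_ids)
print(de_input)                           # [1, 17, 42, 8, 23]
print(de_target)                          # [17, 42, 8, 23, 2]
print(len(de_input) == len(fr_ids) - 1)   # True
```

Both halves have length fr_len-1, which is why the decoder Input layer and the prediction targets share that sequence length.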
For this exercise, all the required keras.layers and Model have been imported. Furthermore, the variables en_len (English sequence length), fr_len (French sequence length), en_vocab (English vocabulary size), fr_vocab (French vocabulary size) and hsize (hidden size) have been defined.
This exercise is part of the course Machine Translation with Keras.
Exercise instructions
- Define an Input layer which accepts a sequence of word IDs.
- Define an Embedding layer that embeds en_vocab words, produces embedding vectors of length 96 and can accept a sequence of IDs (the sequence length is specified using the input_length argument).
- Define an Embedding layer that embeds fr_vocab words, produces embedding vectors of length 96 and can accept a sequence of fr_len-1 IDs.
- Define a model that takes an input from the encoder and an input from the decoder (in that order) and outputs the word predictions.
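Conceptually, an Embedding layer is a trainable lookup table with one row per word ID. A layer that embeds en_vocab words into vectors of length 96 holds an en_vocab × 96 weight matrix and turns a batch of shape (batch, sequence length) into one of shape (batch, sequence length, 96). The pure-Python sketch below mimics only that lookup step, using a tiny hypothetical table (vocabulary of 5, vector length 3) filled with small random values in place of trained weights; `embed` is a hypothetical helper, not a Keras function.

```python
import random


def embed(batch_ids, table):
    """Replace every word ID in a (batch, seq_len) nested list with its row
    from table, giving a (batch, seq_len, dim) nested list."""
    return [[table[word_id] for word_id in seq] for seq in batch_ids]


vocab_size, emb_dim = 5, 3  # hypothetical sizes; the exercise uses en_vocab/fr_vocab and 96
random.seed(0)
# One row of length emb_dim per word ID; random values stand in for trained weights
table = [[random.uniform(-0.05, 0.05) for _ in range(emb_dim)]
         for _ in range(vocab_size)]

batch_ids = [[0, 3, 4, 1], [2, 2, 0, 0]]  # 2 sequences of 4 word IDs each
embedded = embed(batch_ids, table)

print(len(embedded), len(embedded[0]), len(embedded[0][0]))  # 2 4 3
print(embedded[0][1] == table[3])  # True: word ID 3 maps to row 3
```

The same ID always maps to the same row, so repeated words (ID 2 above) receive identical vectors until training updates the table.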
Have a go at this exercise by completing this sample code.
# Define an input layer which accepts a sequence of word IDs
en_inputs = Input(____=(____,))
# Define an Embedding layer which accepts en_inputs
en_emb = ____(____, ____, input_length=____)(en_inputs)
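# Encoder GRU; return_state=True also returns its final hidden state
en_out, en_state = GRU(hsize, return_state=True)(en_emb)
# Decoder input: a sequence of fr_len-1 word IDs (Teacher Forcing)
de_inputs = Input(shape=(fr_len-1,))
# Define an Embedding layer which accepts de_inputs
de_emb = Embedding(____, 96, input_length=____)(____)
# Decoder GRU, initialised with the encoder's final state, returns an output at every time step
de_out, _ = GRU(hsize, return_sequences=True, return_state=True)(de_emb, initial_state=en_state)
# Softmax over the French vocabulary at every decoder time step
de_pred = TimeDistributed(Dense(fr_vocab, activation='softmax'))(de_out)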
# Define the Model which accepts encoder/decoder inputs and outputs predictions
nmt_emb = Model([____, ____], ____)
# categorical_crossentropy expects one-hot targets of shape (batch, fr_len-1, fr_vocab)
nmt_emb.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
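One way the completed model could look is sketched below; it is not the official exercise solution. It assumes TensorFlow 2.x with its bundled Keras, and small hypothetical values for en_len, fr_len, en_vocab, fr_vocab and hsize (in the exercise these are already defined). It omits input_length because each Input shape already fixes the sequence length, and newer Keras releases (Keras 3) deprecate that argument. A short smoke test on random word IDs checks the output shape.

```python
import numpy as np
from tensorflow.keras.layers import Input, Embedding, GRU, Dense, TimeDistributed
from tensorflow.keras.models import Model

# Hypothetical sizes; the exercise defines these variables for you
en_len, fr_len = 15, 20
en_vocab, fr_vocab = 150, 200
hsize = 48

# Encoder: en_len word IDs -> 96-dim embeddings -> final GRU state
en_inputs = Input(shape=(en_len,))
en_emb = Embedding(en_vocab, 96)(en_inputs)
en_out, en_state = GRU(hsize, return_state=True)(en_emb)

# Decoder: fr_len-1 word IDs (Teacher Forcing) -> 96-dim embeddings
de_inputs = Input(shape=(fr_len - 1,))
de_emb = Embedding(fr_vocab, 96)(de_inputs)
de_out, _ = GRU(hsize, return_sequences=True, return_state=True)(
    de_emb, initial_state=en_state)
de_pred = TimeDistributed(Dense(fr_vocab, activation='softmax'))(de_out)

# Encoder input first, decoder input second
nmt_emb = Model([en_inputs, de_inputs], de_pred)
nmt_emb.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

# Smoke test on random word IDs: one softmax over fr_vocab per decoder step
x_en = np.random.randint(0, en_vocab, size=(2, en_len))
x_de = np.random.randint(0, fr_vocab, size=(2, fr_len - 1))
pred = nmt_emb.predict([x_en, x_de], verbose=0)
print(pred.shape)  # (2, 19, 200)
```

The output has one probability distribution over fr_vocab for each of the fr_len-1 decoder steps, matching the one-hot targets that categorical_crossentropy expects.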