1. Encoder-decoder transformers
Let's build a full encoder-decoder transformer for sequence-to-sequence language tasks like text translation.
2. Encoder meets decoder
To connect our encoder and decoder bodies, we need to direct the encoder outputs into the decoder.
3. Encoder meets decoder
To do this, we need one more component: a cross-attention mechanism.
4. Cross-attention mechanism
Cross-attention is another variant of the attention mechanism we used for self-attention. It occurs in each decoder layer after the masked self-attention stage, taking two inputs:
the information processed through the decoder,
and the final hidden states produced by the encoder,
thereby linking the two transformer blocks.
This is crucial for the decoder to "look back" at the input sequence to figure out what to generate next in the target sequence.
Consider this English-to-Spanish translation example. Given the sequence "I really like to travel", cross-attention identifies key words in the original English sequence to generate the next Spanish word: in this case, "travel" receives the highest attention.
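To make the query, key, and value roles concrete, here is a minimal single-head sketch of cross-attention. The tensors, shapes, and values are made up purely for illustration; the MultiHeadAttention class we built performs the same computation in parallel across several heads.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: batch of 1, source length 5 ("I really like to travel"),
# 3 target tokens generated so far, model dimension 8.
d_model = 8
encoder_states = torch.randn(1, 5, d_model)  # final encoder hidden states
decoder_states = torch.randn(1, 3, d_model)  # decoder states after masked self-attention

# Queries come from the decoder; keys and values come from the encoder
queries, keys, values = decoder_states, encoder_states, encoder_states
scores = queries @ keys.transpose(-2, -1) / d_model ** 0.5
weights = F.softmax(scores, dim=-1)  # attention weights over the source tokens
context = weights @ values           # shape: (1, 3, d_model)
print(weights[0, -1])  # how strongly the latest target position attends to each English word
```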
5. Modifying the DecoderLayer
This additional attention stage is implemented inside the decoder layer class using the same MultiHeadAttention class we created.
The forward() method now requires two masks: the causal mask for the first attention stage, and the cross-attention mask, which can be the same padding mask used in the encoder.
Importantly, the variable y in this method represents the encoder outputs, which are passed as the key and value arguments to the cross-attention mechanism. Meanwhile, the decoder flow, x, associated with generating the target sequence, now acts only as the attention query. After cross-attention, the forward pass continues through the feed-forward sublayer as usual.
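As a rough sketch, the modified decoder layer could look like the following. It reuses the MultiHeadAttention class along with a feed-forward sublayer class from earlier in the course; their exact names and signatures are assumptions here and may differ from yours, so treat this as an outline rather than the exact course solution.

```python
import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        # MultiHeadAttention and FeedForwardSubLayer stand in for the classes
        # built earlier in the course; their signatures are assumed here.
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)  # new attention stage
        self.ff = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y, causal_mask, cross_mask):
        # Masked self-attention over the target sequence
        x = self.norm1(x + self.dropout(self.self_attn(x, x, x, causal_mask)))
        # Cross-attention: x provides the query; the encoder output y provides keys and values
        x = self.norm2(x + self.dropout(self.cross_attn(x, y, y, cross_mask)))
        # Feed-forward sublayer as before
        return self.norm3(x + self.dropout(self.ff(x)))
```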
6. Modifying TransformerDecoder
Because we altered the DecoderLayer class to include cross-attention and its mask, we also need to make matching changes to the forward() method of the TransformerDecoder class, passing the encoder outputs, y, and the cross-attention mask to each decoder layer in addition to x and the causal mask.
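Building on the DecoderLayer sketch above, the modified class could look like this; the embedding is simplified and the positional-encoding details follow whatever you set up earlier in the course.

```python
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)  # positional encodings omitted for brevity
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

    def forward(self, x, y, causal_mask, cross_mask):
        x = self.embedding(x)
        # Each decoder layer now also receives the encoder outputs y and the cross-attention mask
        for layer in self.layers:
            x = layer(x, y, causal_mask, cross_mask)
        return x
```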
7. Encoder meets decoder
The final encoder outputs are now fed into every layer of the decoder for cross-attention.
8. Transformer head
Similar to decoder-only transformers, the model's output head consists of a linear layer followed by softmax activation, converting decoder outputs into next-word probabilities.
In our translation example, these are the probabilities of different Spanish words being generated next. For other language tasks, an activation function other than softmax may be required.
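A minimal sketch of such a head follows; TransformerHead is a hypothetical name, and in practice you might return the raw logits and let the loss function apply the softmax instead.

```python
import torch.nn as nn
import torch.nn.functional as F

class TransformerHead(nn.Module):  # hypothetical name for the output head
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, decoder_output):
        logits = self.linear(decoder_output)  # shape: (batch, target_len, vocab_size)
        return F.softmax(logits, dim=-1)      # next-word probabilities over the vocabulary
```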
9. Everything brought together!
Here is the full transformer architecture we just learned to build. Let's use the classes we've built throughout the course to create an encoder-decoder Transformer class.
10. Everything brought together!
With all our classes defined to build the layer components, the layers themselves, and the transformer blocks,
we can create a Transformer class that encapsulates everything.
We initialize the transformer encoder and decoder stacks,
and define a two-stage forward pass: the input sequence, x, first goes through the encoder, and is then passed to the decoder together with the encoder output and the cross-attention mask.
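Here is one way the pieces could fit together, assuming a TransformerEncoder class from earlier in the course plus the sketches above; the argument names and mask handling are illustrative and may differ from your implementation.

```python
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout):
        super().__init__()
        # TransformerEncoder is assumed from earlier in the course
        self.encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout)
        self.decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout)

    def forward(self, x, padding_mask, causal_mask):
        # Stage 1: encode the input sequence
        encoder_output = self.encoder(x, padding_mask)
        # Stage 2: decode, reusing the padding mask as the cross-attention mask
        return self.decoder(x, encoder_output, causal_mask, padding_mask)
```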
11. Let's practice!
Wow - what a journey and result! Time to put these final pieces into place.