
Adding cross-attention to the decoder layer

To integrate the encoder and decoder stacks you've defined previously into an encoder-decoder transformer, you need a cross-attention mechanism to act as a bridge between the two: each decoder layer attends over the encoder's output in addition to its own states.

The MultiHeadAttention class you defined previously is still available.
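Before tackling the exercise, it can help to see the idea in isolation. The sketch below uses PyTorch's built-in nn.MultiheadAttention (not the course's custom MultiHeadAttention class) purely to illustrate what defines cross-attention: the queries come from the decoder states, while the keys and values come from the encoder output. The shapes and hyperparameters are illustrative assumptions.

import torch
import torch.nn as nn

# Illustration only: PyTorch's built-in module, not the course's custom class.
d_model, num_heads = 512, 8
cross_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

decoder_states = torch.randn(2, 10, d_model)   # (batch, tgt_len, d_model)
encoder_output = torch.randn(2, 12, d_model)   # (batch, src_len, d_model)

# Queries from the decoder; keys and values from the encoder output.
bridged, _ = cross_attention(decoder_states, encoder_output, encoder_output)
print(bridged.shape)  # torch.Size([2, 10, 512])

The course's MultiHeadAttention class follows the same query/key/value pattern, which is exactly what you will wire up in the decoder layer below.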

This exercise is part of the course Transformer Models with PyTorch.

Exercise instructions

  • Define a cross-attention mechanism (using MultiHeadAttention) and a third layer normalization (using nn.LayerNorm) in the __init__ method.
  • Complete the forward pass to add cross-attention to the decoder layer.

Hands-on interactive exercise

Have a go at this exercise by completing this sample code.

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        # Define cross-attention and a third layer normalization
        self.cross_attn = ____
        self.ff_sublayer = FeedForwardSubLayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = ____
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y, tgt_mask, cross_mask):
        # x: target sequence embeddings (decoder input), y: encoder output
        # tgt_mask: causal mask for self-attention, cross_mask: mask over the source sequence
        self_attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))
        # Complete the forward pass
        cross_attn_output = self.____(____)
        x = self.norm2(x + self.dropout(____))
        ff_output = self.ff_sublayer(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x
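
Once you have filled in the blanks, a quick smoke test like the one below can confirm that the layer preserves the expected output shape. The hyperparameters, tensor shapes, and mask construction here are illustrative assumptions; the exact mask shapes your MultiHeadAttention implementation expects may differ.

import torch

# Hypothetical hyperparameters for a quick check.
d_model, num_heads, d_ff, dropout = 512, 8, 2048, 0.1
batch_size, src_len, tgt_len = 2, 12, 10

decoder_layer = DecoderLayer(d_model, num_heads, d_ff, dropout)

x = torch.randn(batch_size, tgt_len, d_model)  # target embeddings (decoder input)
y = torch.randn(batch_size, src_len, d_model)  # encoder output

# Causal mask for self-attention; full visibility over the source for cross-attention.
tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len)).unsqueeze(0)
cross_mask = torch.ones(1, tgt_len, src_len)

output = decoder_layer(x, y, tgt_mask, cross_mask)
print(output.shape)  # expected: torch.Size([2, 10, 512])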