ComenzarEmpieza gratis

Añadir métodos a la clase MultiHeadAttention

En este ejercicio, vas a completar el resto de la clase MultiHeadAttention desde cero definiendo cuatro métodos:

  • .split_heads(): divide y transforma las incrustaciones de entrada entre las cabezas de atención
  • .compute_attention(): calcula la atención por producto punto escalado y la multiplica por la matriz de valores
  • .combine_heads(): transforma los pesos de atención de vuelta a la misma forma que las incrustaciones de entrada, x
  • .forward(): llama a los otros métodos para pasar las incrustaciones de entrada por cada proceso

torch.nn se ha importado como nn, torch.nn.functional está disponible como F, y torch también está disponible.

Este ejercicio forma parte del curso

Modelos Transformer con PyTorch

Ver curso

Ejercicio interactivo práctico

Prueba este ejercicio y completa el código de muestra.

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        seq_length = x.size(1)
        # Split the input embeddings and permute
        x = x.____
        return x.permute(0, 2, 1, 3)
Editar y ejecutar código