Añadir métodos a la clase MultiHeadAttention
En este ejercicio, vas a completar el resto de la clase MultiHeadAttention desde cero definiendo cuatro métodos:
.split_heads(): divide y transforma las incrustaciones de entrada entre las cabezas de atención.compute_attention(): calcula la atención por producto punto escalado y la multiplica por la matriz de valores.combine_heads(): transforma los pesos de atención de vuelta a la misma forma que las incrustaciones de entrada,x.forward(): llama a los otros métodos para pasar las incrustaciones de entrada por cada proceso
torch.nn se ha importado como nn, torch.nn.functional está disponible como F, y torch también está disponible.
Este ejercicio forma parte del curso
Modelos Transformer con PyTorch
Ejercicio interactivo práctico
Prueba este ejercicio y completa el código de muestra.
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.num_heads = num_heads
self.d_model = d_model
self.head_dim = d_model // num_heads
self.query_linear = nn.Linear(d_model, d_model, bias=False)
self.key_linear = nn.Linear(d_model, d_model, bias=False)
self.value_linear = nn.Linear(d_model, d_model, bias=False)
self.output_linear = nn.Linear(d_model, d_model)
def split_heads(self, x, batch_size):
seq_length = x.size(1)
# Split the input embeddings and permute
x = x.____
return x.permute(0, 2, 1, 3)