1. 학습
  2. /
  3. 강의
  4. /
  5. PyTorch로 배우는 Transformer 모델

Connected

연습 문제

MultiHeadAttention 클래스에 메서드 추가하기

이 연습 문제에서는 네 가지 메서드를 정의해 MultiHeadAttention 클래스를 처음부터 완성해 볼 거예요.

  • .split_heads(): 입력 임베딩을 어텐션 헤드 수만큼 분할하고 변환해요.
  • .compute_attention(): 스케일드 도트 프로덕트 어텐션 가중치를 계산하고 value 행렬과 곱해요.
  • .combine_heads(): 어텐션 가중치를 입력 임베딩 x와 동일한 형태로 되돌려요.
  • .forward(): 위 메서드들을 호출해 입력 임베딩을 각 단계로 전달해요.

torch.nn은 nn으로 임포트되어 있고, torch.nn.functional은 F로, torch도 사용 가능해요.

지침 1/4

undefined XP
    1
    2
    3
    4
  • 입력 임베딩 x를 어텐션 헤드 수로 분할하기 위해 (batch_size, seq_length, self.num_heads, self.head_dim) 형태로 리쉐이프하세요.