1. 학습
  2. /
  3. 강의
  4. /
  5. PyTorch로 배우는 이미지 딥러닝

Connected

연습 문제

Generator loss

GAN을 학습시키기 전에, 생성기와 판별기 모두에 대한 손실 함수를 정의해야 해요. 먼저 생성기부터 시작해 보겠습니다.

생성기의 역할은 판별기가 진짜로 분류하도록 속일 수 있는 가짜 이미지를 만드는 것입니다. 따라서 생성기가 만든 이미지가 판별기에 의해 가짜(레이블 0)로 분류되면 생성기에는 손실이 발생합니다.

생성기 손실을 계산하는 gen_loss() 함수를 정의하세요. 이 함수는 네 가지 인수를 받습니다:

  • gen: 생성기 모델
  • disc: 판별기 모델
  • num_images: 배치 내 이미지 수
  • z_dim: 입력 무작위 노이즈의 차원 크기

지침

100 XP
  • 모양이 num_images by z_dim인 무작위 노이즈를 생성해 noise에 할당하세요.
  • 생성기를 사용해 noise로부터 가짜 이미지를 생성하고 fake에 할당하세요.
  • 생성된 가짜 이미지에 대한 판별기의 예측을 얻으세요.
  • 판별기의 예측과 같은 모양의 1로 이루어진 텐서를 대상으로 criterion을 호출해 생성기 손실을 계산하세요.