1. Nauka
  2. /
  3. Kursy
  4. /
  5. Głębokie uczenie dla obrazów z PyTorch

Connected

ćwiczenie

Strata generatora

Zanim zaczniesz trenować sieć GAN, musisz zdefiniować funkcje straty dla generatora i dyskryminatora. Zacznij od tego pierwszego.

Przypomnij sobie, że zadaniem generatora jest tworzenie fałszywych obrazów, które zmylą dyskryminatora i skłonią go do sklasyfikowania ich jako prawdziwe. Generator ponosi więc stratę, gdy dyskryminator rozpoznaje wygenerowane obrazy jako fałszywe (etykieta 0).

Zdefiniuj funkcję gen_loss(), która oblicza stratę generatora. Przyjmuje cztery argumenty:

  • gen – model generatora
  • disc – model dyskryminatora
  • num_images – liczba obrazów w partii
  • z_dim – rozmiar wejściowego losowego szumu

Instrukcje

100 XP
  • Wygeneruj losowy szum o kształcie num_images na z_dim i przypisz go do noise.
  • Użyj generatora, aby wygenerować fałszywy obraz na podstawie noise i przypisz go do fake.
  • Uzyskaj predykcję dyskryminatora dla wygenerowanego fałszywego obrazu.
  • Oblicz stratę generatora, wywołując criterion na predykcjach dyskryminatora i tensorze jedynek o tym samym kształcie.