1. Nauka
  2. /
  3. Kursy
  4. /
  5. Głębokie uczenie dla obrazów z PyTorch

Connected

ćwiczenie

Pętla treningowa

Nadszedł czas, by cała praca włożona w definiowanie architektur modeli i funkcji straty przyniosła efekty – czas na trening! Twoim zadaniem jest zaimplementowanie i uruchomienie pętli treningowej GAN. Uwaga: po pierwszej partii danych umieszczono instrukcję break, aby uniknąć długiego czasu wykonania.

Oba optymizatory, disc_opt i gen_opt, zostały zainicjalizowane jako optymizatory Adam(). Funkcje do obliczania strat zdefiniowane wcześniej – gen_loss() i disc_loss() – są już dostępne. Przygotowano również dataloader.

Pamiętaj, że:

  • Argumenty disc_loss() to: gen, disc, real, cur_batch_size, z_dim.
  • Argumenty gen_loss() to: gen, disc, cur_batch_size, z_dim.

Instrukcje

100 XP
  • Oblicz stratę dyskryminatora za pomocą disc_loss(), przekazując jej generator, dyskryminator, próbkę prawdziwych obrazów, bieżący rozmiar partii oraz rozmiar szumu 16 – w tej kolejności – i przypisz wynik do d_loss.
  • Oblicz gradienty, używając d_loss.
  • Oblicz stratę generatora za pomocą gen_loss(), przekazując jej generator, dyskryminator, bieżący rozmiar partii oraz rozmiar szumu 16 – w tej kolejności – i przypisz wynik do g_loss.
  • Oblicz gradienty, używając g_loss.