1. Učit se
  2. /
  3. Kurzy
  4. /
  5. Deep Learning pro obrázky s PyTorchem

Connected

cvičení

Trénovací smyčka

A je to tady – veškerá tvoje práce na definování architektur modelu a ztrátových funkcí teď přijde ke slovu: čas trénovat! Tvým úkolem je implementovat a spustit trénovací smyčku GAN. Poznámka: po první dávce dat je vložen příkaz break, aby nedocházelo k příliš dlouhému běhu.

Oba optimizéry, disc_opt a gen_opt, jsou inicializovány jako optimizéry Adam(). K dispozici máš také funkce pro výpočet ztrát, které jsi definoval/a dříve: gen_loss() a disc_loss(). Připraven je i dataloader.

Připomeň si:

  • Argumenty disc_loss() jsou: gen, disc, real, cur_batch_size, z_dim.
  • Argumenty gen_loss() jsou: gen, disc, cur_batch_size, z_dim.

Pokyny

100 XP
  • Vypočítej ztrátu diskriminátoru pomocí disc_loss() – předej jí generátor, diskriminátor, vzorek reálných obrázků, aktuální velikost dávky a velikost šumu 16 (v tomto pořadí) a výsledek ulož do d_loss.
  • Vypočítej gradienty pomocí d_loss.
  • Vypočítej ztrátu generátoru pomocí gen_loss() – předej jí generátor, diskriminátor, aktuální velikost dávky a velikost šumu 16 (v tomto pořadí) a výsledek ulož do g_loss.
  • Vypočítej gradienty pomocí g_loss.