ComeçarComece de graça

Loop de treinamento

Finalmente, todo o esforço que você teve para definir as arquiteturas dos modelos e as funções de loss vai valer a pena: é hora de treinar! Sua tarefa é implementar e executar o loop de treinamento da GAN. Observação: há um comando break após o primeiro batch de dados para evitar uma execução demorada.

Os dois otimizadores, disc_opt e gen_opt, já foram inicializados como otimizadores Adam(). As funções de cálculo das losses que você definiu antes, gen_loss() e disc_loss(), estão disponíveis. Um dataloader também está preparado para você.

Lembre-se de que:

  • Os argumentos de disc_loss() são: gen, disc, real, cur_batch_size, z_dim.
  • Os argumentos de gen_loss() são: gen, disc, cur_batch_size, z_dim.

Este exercício faz parte do curso

Deep Learning para Imagens com PyTorch

Ver curso

Instruções do exercício

  • Calcule a loss do discriminador usando disc_loss(), passando, nessa ordem, o gerador, o discriminador, a amostra de imagens reais, o tamanho atual do batch e o tamanho do ruído igual a 16, e atribua o resultado a d_loss.
  • Calcule os gradientes usando d_loss.
  • Calcule a loss do gerador usando gen_loss(), passando, nessa ordem, o gerador, o discriminador, o tamanho atual do batch e o tamanho do ruído igual a 16, e atribua o resultado a g_loss.
  • Calcule os gradientes usando g_loss.

Exercício interativo prático

Experimente este exercício completando este código de exemplo.

for epoch in range(1):
    for real in dataloader:
        cur_batch_size = len(real)
        
        disc_opt.zero_grad()
        # Calculate discriminator loss
        d_loss = ____
        # Compute gradients
        ____
        disc_opt.step()

        gen_opt.zero_grad()
        # Calculate generator loss
        g_loss = ____
        # Compute generator gradients
        ____
        gen_opt.step()

        print(f"Generator loss: {g_loss}")
        print(f"Discriminator loss: {d_loss}")
        break
Editar e executar o código