Loop de treinamento
Finalmente, todo o trabalho duro que você teve pra definir as arquiteturas do modelo e as funções de perda valeu a pena: é hora do treinamento! A tua tarefa é implementar e executar o ciclo de treinamento GAN. Observação: uma instrução “ break ” é colocada depois do primeiro lote de dados pra evitar um tempo de execução muito longo.
Os dois otimizadores, disc_opt e gen_opt, foram configurados como otimizadores de Adam(). As funções para calcular as perdas que você definiu anteriormente, gen_loss() e disc_loss(), estão disponíveis para você. Também preparamos um guia de viagem ( dataloader ) para você.
Lembre-se que:
disc_loss()Os argumentos de 's são:gen,disc,real,cur_batch_size,z_dim.gen_loss()Os argumentos de 's são:gen,disc,cur_batch_size,z_dim.
Este exercício faz parte do curso
Aprendizado profundo para imagens com PyTorch
Instruções do exercício
- Calcule a perda do discriminador usando
disc_loss()passando o gerador, o discriminador, a amostra de imagens reais, o tamanho do lote atual e o tamanho do ruído de16, nessa ordem, e atribua o resultado ad_loss. - Calcule gradientes usando
d_loss. - Calcule a perda do gerador usando
gen_loss()passando o gerador, o discriminador, o tamanho do lote atual e o tamanho do ruído de16, nessa ordem, e atribua o resultado ag_loss. - Calcule gradientes usando
g_loss.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
for epoch in range(1):
for real in dataloader:
cur_batch_size = len(real)
disc_opt.zero_grad()
# Calculate discriminator loss
d_loss = ____
# Compute gradients
____
disc_opt.step()
gen_opt.zero_grad()
# Calculate generator loss
g_loss = ____
# Compute generator gradients
____
gen_opt.step()
print(f"Generator loss: {g_loss}")
print(f"Discriminator loss: {d_loss}")
break