Loop de treinamento
Finalmente, todo o trabalho duro que você teve pra definir as arquiteturas do modelo e as funções de perda valeu a pena: é hora do treinamento! A tua tarefa é implementar e executar o ciclo de treinamento GAN. Observação: uma instrução “ break
” é colocada depois do primeiro lote de dados pra evitar um tempo de execução muito longo.
Os dois otimizadores, disc_opt
e gen_opt
, foram configurados como otimizadores de Adam()
. As funções para calcular as perdas que você definiu anteriormente, gen_loss()
e disc_loss()
, estão disponíveis para você. Também preparamos um guia de viagem ( dataloader
) para você.
Lembre-se que:
disc_loss()
Os argumentos de 's são:gen
,disc
,real
,cur_batch_size
,z_dim
.gen_loss()
Os argumentos de 's são:gen
,disc
,cur_batch_size
,z_dim
.
Este exercício faz parte do curso
Aprendizado profundo para imagens com PyTorch
Instruções do exercício
- Calcule a perda do discriminador usando
disc_loss()
passando o gerador, o discriminador, a amostra de imagens reais, o tamanho do lote atual e o tamanho do ruído de16
, nessa ordem, e atribua o resultado ad_loss
. - Calcule gradientes usando
d_loss
. - Calcule a perda do gerador usando
gen_loss()
passando o gerador, o discriminador, o tamanho do lote atual e o tamanho do ruído de16
, nessa ordem, e atribua o resultado ag_loss
. - Calcule gradientes usando
g_loss
.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
for epoch in range(1):
for real in dataloader:
cur_batch_size = len(real)
disc_opt.zero_grad()
# Calculate discriminator loss
d_loss = ____
# Compute gradients
____
disc_opt.step()
gen_opt.zero_grad()
# Calculate generator loss
g_loss = ____
# Compute generator gradients
____
gen_opt.step()
print(f"Generator loss: {g_loss}")
print(f"Discriminator loss: {d_loss}")
break