Treinando um modelo GAN
Sua equipe na PyBooks avançou bem na construção do gerador de texto usando uma Generative Adversarial Network (GAN). Você já definiu com sucesso as redes geradora e discriminadora. Agora é hora de treiná-las. A etapa final é gerar alguns dados falsos e compará-los com os dados reais para ver o quanto sua GAN aprendeu. Usamos tensores como entrada e a saída tentará se assemelhar aos tensores de entrada. A equipe da PyBooks pode então usar esses dados sintéticos para análise de texto, já que as features manterão a mesma relação dos dados de texto.
O gerador e o discriminador foram inicializados e salvos em generator e discriminator, respectivamente.
As seguintes variáveis foram inicializadas no exercício:
seq_length = 5: Tamanho de cada sequência de dados sintéticosnum_sequences = 100: Total de sequências geradasnum_epochs = 50: Número de passagens completas pelo conjunto de dadosprint_every = 10: Frequência de exibição de resultados, mostrando a cada 10 épocas
Este exercício faz parte do curso
Deep Learning para Texto com PyTorch
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
# Define the loss function and optimizer
criterion = nn.____()
optimizer_gen = ____(generator.parameters(), lr=0.001)
optimizer_disc = ____(discriminator.parameters(), lr=0.001)
for epoch in range(num_epochs):
for real_data in data:
# Unsqueezing real_data and prevent gradient recalculations
real_data = real_data.____(0)
noise = torch.rand((1, seq_length))
fake_data = generator(noise)
disc_real = discriminator(real_data)
disc_fake = discriminator(fake_data.____())
loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(disc_fake, torch.zeros_like(disc_fake))
optimizer_disc.zero_grad()
loss_disc.backward()
optimizer_disc.step()