Entrenamiento del modelo
En este ejercicio, entrenarás el modelo implementado anteriormente. ¿Sabías que el modelo de traducción automática basado en codificadores-decodificadores de Google tardó entre dos y cuatro días en entrenarse?
Para este ejercicio, utilizarás un pequeño conjunto de datos de 1500 frases (es decir, en_text
y fr_text
) para entrenar el modelo. Esta cantidad difícilmente será suficiente para obtener un buen rendimiento, pero el método seguirá siendo el mismo. Se trata de entrenar con más datos durante más tiempo. También se te ha proporcionado el modelo nmt
y la función sents2seqs()
que implementaste anteriormente. En este ejercicio, invertirás el texto del codificador para obtener un mejor rendimiento. Aquí, « en_x
» se refiere a la entrada del codificador, mientras que « de_x
» se refiere a la entrada del decodificador.
Este ejercicio forma parte del curso
Traducción automática con Keras
Instrucciones del ejercicio
- Obtén un único lote de entradas del codificador (frases en inglés desde el índice
i
hastai+bsize
) utilizando la funciónsents2seqs()
. Las entradas deben invertirse y codificarse con onehot. - Obtén un único lote de salidas del decodificador (frases en francés desde el índice
i
hastai+bsize
) utilizando la funciónsents2seqs()
. Las entradas deben estar codificadas en onehot. - Entrena el modelo con un único lote de datos que contenga
en_x
yde_y
. - Obtener las métricas de evaluación para
en_x
yde_y
evaluando el modelo con unbatch_size
debsize
.
Ejercicio interactivo práctico
Prueba este ejercicio completando el código de muestra.
n_epochs, bsize = 3, 250
for ei in range(n_epochs):
for i in range(0,data_size,bsize):
# Get a single batch of encoder inputs
en_x = ____('source', ____, onehot=____, reverse=____)
# Get a single batch of decoder outputs
de_y = sents2seqs('target', fr_text[____], onehot=____)
# Train the model on a single batch of data
nmt.____(____, ____)
# Obtain the eval metrics for the training data
res = nmt.____(____, de_y, batch_size=____, verbose=0)
print("{} => Train Loss:{}, Train Acc: {}".format(ei+1,res[0], res[1]*100.0))