Local SGD con Accelerator
Ya has implementado la acumulación de gradientes y el gradient checkpointing para optimizar el uso de memoria de tu modelo de traducción automática. El entrenamiento sigue siendo algo lento, así que decides añadir local SGD a tu bucle de entrenamiento para mejorar la eficiencia de comunicación entre dispositivos. ¡Construye el bucle de entrenamiento con local SGD!
model, train_dataloader y accelerator ya están predefinidos, y LocalSGD está importado.
Este ejercicio forma parte del curso
Entrenamiento eficiente de modelos de IA con PyTorch
Instrucciones del ejercicio
- Configura un gestor de contexto para local SGD y sincroniza los gradientes cada ocho pasos.
- Ejecuta el gestor de contexto de local SGD con
.step().
Ejercicio interactivo práctico
Prueba este ejercicio y completa el código de muestra.
# Set up a context manager to synchronize gradients every eight steps
with ____(accelerator=accelerator, model=model, ____=____, enabled=True) as local_sgd:
for batch in train_dataloader:
with accelerator.accumulate(model):
inputs, targets = batch["input_ids"], batch["labels"]
outputs = model(inputs, labels=targets)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Step the local SGD context manager
____.____()