Local SGD con Accelerator
Ya has implementado la acumulación de gradientes y el gradient checkpointing para optimizar el uso de memoria en tu modelo de traducción automática. El entrenamiento sigue siendo algo lento, así que decides añadir Local SGD a tu bucle de entrenamiento para mejorar la eficiencia de comunicación entre dispositivos. ¡Construye el bucle de entrenamiento con Local SGD!
model, train_dataloader y accelerator ya están definidos, y LocalSGD está importado.
Este ejercicio forma parte del curso
Entrenamiento eficiente de modelos de IA con PyTorch
Instrucciones del ejercicio
- Establece
local_sgd_stepspara sincronizar los gradientes cada ocho pasos. - Ejecuta el gestor de contexto de Local SGD con
.step().
ejercicio interactivo práctico
Prueba este ejercicio completando este código de ejemplo.
# Set up a context manager to synchronize gradients every eight steps
with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=____, enabled=True) as local_sgd:
for batch in train_dataloader:
with accelerator.accumulate(model):
inputs, targets = batch["input_ids"], batch["labels"]
outputs = model(inputs, labels=targets)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Step the local SGD context manager
local_sgd.____()