SGD local con acelerador
Has implementado la acumulación de gradiente y la comprobación de gradiente para racionalizar el uso de memoria de tu modelo de traducción de idiomas. El entrenamiento sigue siendo un poco lento, así que decides añadir SGD local a tu bucle de entrenamiento para mejorar la eficacia de la comunicación entre dispositivos. ¡Construye el bucle de entrenamiento con SGD local!
Se han predefinido model
, train_dataloader
y accelerator
, y se ha importado LocalSGD
.
Este ejercicio forma parte del curso
Entrenamiento eficiente de modelos de IA con PyTorch
Instrucciones del ejercicio
- Configura un gestor de contexto para la SGD local, y sincroniza los gradientes cada ocho pasos.
- Paso al gestor de contexto SGD local.
Ejercicio interactivo práctico
Prueba este ejercicio completando el código de muestra.
# Set up a context manager to synchronize gradients every eight steps
with ____(accelerator=accelerator, model=model, ____=____, enabled=True) as local_sgd:
for batch in train_dataloader:
with accelerator.accumulate(model):
inputs, targets = batch["input_ids"], batch["labels"]
outputs = model(inputs, labels=targets)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Step the local SGD context manager
____.____()