EmpezarEmpieza gratis

Local SGD con Accelerator

Ya has implementado la acumulación de gradientes y el gradient checkpointing para optimizar el uso de memoria en tu modelo de traducción automática. El entrenamiento sigue siendo algo lento, así que decides añadir Local SGD a tu bucle de entrenamiento para mejorar la eficiencia de comunicación entre dispositivos. ¡Construye el bucle de entrenamiento con Local SGD!

model, train_dataloader y accelerator ya están definidos, y LocalSGD está importado.

Este ejercicio forma parte del curso

Entrenamiento eficiente de modelos de IA con PyTorch

Ver curso

Instrucciones del ejercicio

  • Establece local_sgd_steps para sincronizar los gradientes cada ocho pasos.
  • Ejecuta el gestor de contexto de Local SGD con .step().

ejercicio interactivo práctico

Prueba este ejercicio completando este código de ejemplo.

# Set up a context manager to synchronize gradients every eight steps
with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=____, enabled=True) as local_sgd:
    for batch in train_dataloader:
        with accelerator.accumulate(model):
            inputs, targets = batch["input_ids"], batch["labels"]
            outputs = model(inputs, labels=targets)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            # Step the local SGD context manager
            local_sgd.____()
Editar y ejecutar código