ComenzarEmpieza gratis

Local SGD con Accelerator

Ya has implementado la acumulación de gradientes y el gradient checkpointing para optimizar el uso de memoria de tu modelo de traducción automática. El entrenamiento sigue siendo algo lento, así que decides añadir local SGD a tu bucle de entrenamiento para mejorar la eficiencia de comunicación entre dispositivos. ¡Construye el bucle de entrenamiento con local SGD!

model, train_dataloader y accelerator ya están predefinidos, y LocalSGD está importado.

Este ejercicio forma parte del curso

Entrenamiento eficiente de modelos de IA con PyTorch

Ver curso

Instrucciones del ejercicio

  • Configura un gestor de contexto para local SGD y sincroniza los gradientes cada ocho pasos.
  • Ejecuta el gestor de contexto de local SGD con .step().

Ejercicio interactivo práctico

Prueba este ejercicio y completa el código de muestra.

# Set up a context manager to synchronize gradients every eight steps
with ____(accelerator=accelerator, model=model, ____=____, enabled=True) as local_sgd:
    for batch in train_dataloader:
        with accelerator.accumulate(model):
            inputs, targets = batch["input_ids"], batch["labels"]
            outputs = model(inputs, labels=targets)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            # Step the local SGD context manager
            ____.____()
Editar y ejecutar código