ComenzarEmpieza gratis

SGD local con acelerador

Has implementado la acumulación de gradiente y la comprobación de gradiente para racionalizar el uso de memoria de tu modelo de traducción de idiomas. El entrenamiento sigue siendo un poco lento, así que decides añadir SGD local a tu bucle de entrenamiento para mejorar la eficacia de la comunicación entre dispositivos. ¡Construye el bucle de entrenamiento con SGD local!

Se han predefinido model, train_dataloader y accelerator, y se ha importado LocalSGD.

Este ejercicio forma parte del curso

Entrenamiento eficiente de modelos de IA con PyTorch

Ver curso

Instrucciones del ejercicio

  • Configura un gestor de contexto para la SGD local, y sincroniza los gradientes cada ocho pasos.
  • Paso al gestor de contexto SGD local.

Ejercicio interactivo práctico

Prueba este ejercicio completando el código de muestra.

# Set up a context manager to synchronize gradients every eight steps
with ____(accelerator=accelerator, model=model, ____=____, enabled=True) as local_sgd:
    for batch in train_dataloader:
        with accelerator.accumulate(model):
            inputs, targets = batch["input_ids"], batch["labels"]
            outputs = model(inputs, labels=targets)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            # Step the local SGD context manager
            ____.____()
Editar y ejecutar código