SGD local com acelerador
Você implementou a acumulação de gradiente e o checkpointing de gradiente para otimizar o uso da memória do seu modelo de tradução de idiomas. O treinamento ainda está um pouco lento, então você decide adicionar o SGD local ao seu loop de treinamento para melhorar a eficiência da comunicação entre os dispositivos. Crie o loop de treinamento com o SGD local!
Os sites model
, train_dataloader
e accelerator
foram predefinidos e o site LocalSGD
foi importado.
Este exercício faz parte do curso
Treinamento eficiente de modelos de IA com PyTorch
Instruções do exercício
- Configure um gerenciador de contexto para SGD local e sincronize os gradientes a cada oito etapas.
- Passo a passo o gerenciador de contexto SGD local.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
# Set up a context manager to synchronize gradients every eight steps
with ____(accelerator=accelerator, model=model, ____=____, enabled=True) as local_sgd:
for batch in train_dataloader:
with accelerator.accumulate(model):
inputs, targets = batch["input_ids"], batch["labels"]
outputs = model(inputs, labels=targets)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Step the local SGD context manager
____.____()