1. Mixed precision training with 8-bit Adam
Let's continue improving optimizer memory efficiency with one last optimizer.
2. Optimizers for training efficiency
We saw that Adafactor saves memory by storing fewer optimizer states.
3. Optimizers for training efficiency
8-bit Adam saves memory by storing its optimizer states in low precision, while keeping the same number of states as AdamW. The tradeoff is that some models may take longer to converge. Note that 8-bit Adam requires a GPU to run.
4. How does 8-bit Adam work?
8-bit Adam works by storing its optimizer states in 8-bit floating point, or FP8, and performing its calculations in 32-bit floating point, or FP32. Green indicates FP8, and yellow indicates FP32.
5. How does 8-bit Adam work?
The first step is a forward pass through the model in FP8.
6. How does 8-bit Adam work?
Next, the optimizer applies weight decay, a penalty based on the current parameter values, and computes gradients in FP32.
7. How does 8-bit Adam work?
Then the optimizer computes the exponential moving average, or EMA, of the gradients and squared gradients, in FP32.
8. How does 8-bit Adam work?
It updates the model using these calculations in FP8.
9. How does 8-bit Adam save memory?
For each model parameter,
10. How does 8-bit Adam save memory?
the optimizer stores the EMA of the gradients
11. How does 8-bit Adam save memory?
and squared gradients (sometimes called the first and second moments). Here, each square is a parameter, and each color is a state. 8-bit Adam stores two states per parameter, represented by two colors for each square (green and red). Using FP8 precision, each parameter requires two bytes: one byte per state times two states. The estimated memory usage of 8-bit Adam is two bytes times the number of model parameters.
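To see these two low-precision states directly, here is a minimal sketch, assuming a CUDA GPU and the bitsandbytes package are available (the exact state key names vary across bitsandbytes versions); it runs one optimizer step on a toy layer and prints the dtype of each stored state tensor.

```python
import torch
import bitsandbytes as bnb

# Toy layer with enough elements for 8-bit state storage (8-bit Adam needs a GPU)
layer = torch.nn.Linear(256, 256).cuda()
optimizer = bnb.optim.Adam8bit(layer.parameters(), lr=1e-3)

# One step so the optimizer allocates its states
loss = layer(torch.randn(8, 256, device="cuda")).sum()
loss.backward()
optimizer.step()

# Inspect the stored states for the weight matrix: the two moment tensors
# are kept in 8-bit (uint8), plus small quantization metadata tensors
for name, value in optimizer.state[layer.weight].items():
    if torch.is_tensor(value):
        print(name, value.dtype, tuple(value.shape))
```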
12. Estimate memory usage of 8-bit Adam
For example, when loading a Transformer model with about 65 million parameters, the estimated memory usage for 8-bit Adam is two bytes per parameter, totaling approximately 125 MB. We'll compare this with AdamW later.
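As a quick check, the same estimate can be computed from the parameter count. This sketch assumes a Hugging Face model has already been loaded into model.

```python
# Two optimizer states per parameter, one byte each: two bytes per parameter
num_params = sum(p.numel() for p in model.parameters())
estimated_bytes = 2 * num_params
print(f"Parameters: {num_params / 1e6:.0f}M")
print(f"Estimated 8-bit Adam state memory: {estimated_bytes / 1024**2:.0f} MB")
```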
13. Set up the 8-bit Adam optimizer
To set up 8-bit Adam, we import the bitsandbytes library and get_parameter_names() from transformers. First, we declare TrainingArguments with default parameters. We want to apply weight decay to prevent overfitting, so we build decay_parameters, the list of parameters that should receive weight decay. We extract it by calling get_parameter_names, excluding nn.LayerNorm layer types, and then remove biases, because weight decay doesn't apply to normalization layers or biases.
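Here is a minimal sketch of that setup, assuming model is an already-loaded Transformer; in recent versions of transformers, get_parameter_names lives in transformers.trainer_pt_utils.

```python
import bitsandbytes as bnb
import torch.nn as nn
from transformers import TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names

# Default hyperparameters (weight decay, betas, epsilon, learning rate)
training_args = TrainingArguments(output_dir="./results")

# Parameters that should receive weight decay: everything except
# LayerNorm weights and biases
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
```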
14. Set up the 8-bit Adam optimizer
Next we define groups of parameters to optimize based on whether to apply weight_decay and later feed these parameters into the optimizer. In optimizer_grouped_parameters, one group applies weight_decay to its parameters, but the other group does not, setting weight_decay to zero. We define Adam8bit using optimizer_grouped_parameters. beta1 and beta2 control decay rates of the first and second moments; higher values lead to more stable but slower training. epsilon avoids division by zero during parameter updates. Learning rate controls step size. We use default values from TrainingArguments for these parameters.
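Continuing the sketch, the two parameter groups and the 8-bit optimizer could look like this, with the hyperparameter values pulled from the TrainingArguments defaults.

```python
optimizer_grouped_parameters = [
    {   # Group 1: parameters that receive weight decay
        "params": [p for n, p in model.named_parameters() if n in decay_parameters],
        "weight_decay": training_args.weight_decay,
    },
    {   # Group 2: LayerNorm weights and biases, no weight decay
        "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
        "weight_decay": 0.0,
    },
]

adam_bnb_optim = bnb.optim.Adam8bit(
    optimizer_grouped_parameters,
    betas=(training_args.adam_beta1, training_args.adam_beta2),  # moment decay rates
    eps=training_args.adam_epsilon,                              # avoids division by zero
    lr=training_args.learning_rate,                              # step size
)
```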
15. Implement 8-bit Adam with Trainer
After configuring the optimizer, we pass it to Trainer and monitor metrics such as loss, accuracy, and F1 score. Ideally, loss decreases while accuracy and F1 score increase.
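One way to hand the optimizer to Trainer is through its optimizers argument, which takes an (optimizer, scheduler) pair; train_dataset, eval_dataset, and compute_metrics are assumed to be defined elsewhere.

```python
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,    # e.g. returns accuracy and F1
    optimizers=(adam_bnb_optim, None),  # None lets Trainer build its default scheduler
)
trainer.train()
```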
16. Implement 8-bit Adam with Accelerator
With Accelerator, we prepare training objects, including the optimizer, for distributed training. The optimizer is integrated into the usual loop, and loss is monitored throughout training.
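A sketch of the Accelerator version, assuming train_dataloader is already defined:

```python
from accelerate import Accelerator

accelerator = Accelerator()
model, adam_bnb_optim, train_dataloader = accelerator.prepare(
    model, adam_bnb_optim, train_dataloader
)

model.train()
for batch in train_dataloader:
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)   # replaces loss.backward()
    adam_bnb_optim.step()
    adam_bnb_optim.zero_grad()
    print(f"Loss: {loss.item():.4f}")
```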
17. Compute memory usage of 8-bit Adam
To measure memory savings, we reuse the compute_optimizer_size function we applied earlier to AdamW. With the same number of optimizer states, 8-bit Adam uses just one-fourth of the memory of AdamW.
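The compute_optimizer_size function comes from an earlier lesson, so its exact implementation isn't shown here; a rough reconstruction of what such a helper might look like is below. It sums every tensor held in the optimizer's state, so it should be called after at least one training step, once the states have been allocated.

```python
import torch

def compute_optimizer_size(optimizer):
    """Rough sketch: total elements and megabytes stored in optimizer state tensors."""
    total_elements, total_bytes = 0, 0
    for param_state in optimizer.state.values():
        for value in param_state.values():
            if torch.is_tensor(value):
                total_elements += value.numel()
                total_bytes += value.numel() * value.element_size()
    return total_elements, total_bytes / 1024**2

# Example (hypothetical usage): compare the two optimizers after one step
# elements, megabytes = compute_optimizer_size(adam_bnb_optim)
```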
18. Let's practice!
Time to practice!