Prepare for 8-bit Training
You wanted to begin RLHF fine-tuning, but you kept running into out-of-memory errors. To address this, you decided to switch to 8-bit precision, which enables more memory-efficient fine-tuning by leveraging the Hugging Face peft library.
The following have been pre-imported:
- AutoModelForCausalLM from transformers
- prepare_model_for_int8_training from peft
- AutoModelForCausalLMWithValueHead from trl
This exercise is part of the course Reinforcement Learning from Human Feedback (RLHF).
Exercise instructions
- Load the pre-trained model and make sure to include the parameter for 8-bit precision.
- Use the prepare_model_for_int8_training function to make the model ready for LoRA-based fine-tuning.
- Load the model with a value head for PPO training.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
model_name = "gpt2"
# Load the model in 8-bit precision
pretrained_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    ____=True
)
# Prepare the model for fine-tuning
pretrained_model_8bit = ____(pretrained_model)
# Load the model with a value head
model = ____.from_pretrained(pretrained_model_8bit)
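
For reference, one possible completed version is sketched below. The function and class names come from the pre-imports listed above; the only name not given in the exercise is the load_in_8bit keyword, which is assumed here as the transformers from_pretrained argument for 8-bit loading.

# Imports mirror the pre-imported names listed above
from transformers import AutoModelForCausalLM
from peft import prepare_model_for_int8_training
from trl import AutoModelForCausalLMWithValueHead

model_name = "gpt2"

# Load the base model in 8-bit precision to reduce GPU memory usage
pretrained_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True  # assumed keyword for 8-bit loading in transformers
)

# Prepare the quantized model for LoRA-based fine-tuning
pretrained_model_8bit = prepare_model_for_int8_training(pretrained_model)

# Wrap the prepared model with a value head so it can be trained with PPO
model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model_8bit)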