
Fine-tuning computer vision models

1. Fine-tuning computer vision models

Let's now fine-tune our computer vision models!

2. Purpose of fine-tuning vision models

Large vision models are pretrained on everyday images covering a limited number of classes. Fine-tuning adapts these models to unseen data, such as new classes or new image modalities like X-rays.

3. Fine-tuning vision models

Here, we'll fine-tune a model to detect AI-generated images. To do this, we'll adjust the model so that it correctly maps to the new predicted classes, configure the dataset for training, including preprocessing and resizing, configure the training parameters, and then start training. Let's begin!

4. Model updates

To update the model with information about the new classes, we load the dataset, then use the .train_test_split() method to split it for training and evaluation. Here, we've used 20% of the data for evaluation and set a seed to make the split reproducible. We extract the class names from the .names attribute of the training split's label feature, then use them to construct two dictionaries, id2label and label2id, which map IDs to labels and labels to IDs, respectively.
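
A minimal sketch of these steps (the dataset name here is a placeholder; substitute the dataset you're working with):

```python
from datasets import load_dataset

# Placeholder dataset name; swap in your own image classification dataset
dataset = load_dataset("user/ai-generated-images", split="train")

# Hold out 20% of the data for evaluation; the seed makes the split reproducible
dataset = dataset.train_test_split(test_size=0.2, seed=42)

# Extract the class names from the training split's label feature
labels = dataset["train"].features["label"].names

# Map IDs to labels and labels to IDs
id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: i for i, label in enumerate(labels)}
```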

5. Model updates

We pass these mapping dictionaries as arguments to the .from_pretrained() method along with the model checkpoint. It's important to set the ignore_mismatched_sizes flag: the model's classification head would otherwise be sized for the number of classes it was originally trained on, and we want it to discard that head and use our new labels and IDs instead.
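
For example, with an assumed ViT checkpoint (any image classification checkpoint would follow the same pattern):

```python
from transformers import AutoModelForImageClassification

# Example checkpoint; use the one that suits your task
checkpoint = "google/vit-base-patch16-224"

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    id2label=id2label,
    label2id=label2id,
    # Replace the original classification head, whose output size matches
    # the old label count, with one sized for our new labels
    ignore_mismatched_sizes=True,
)
```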

6. Dataset preparation

The new image data must be preprocessed in the same way as the original training data and converted into tensors, so we start by loading the processor from the same model checkpoint. We use torchvision to set up the processing steps. We normalize pixel intensities using the mean and standard deviation from the .image_mean and .image_std attributes of the processor. The ToTensor() function is also required to convert the images into PyTorch tensors. These two processing steps are combined using the Compose() function; the resulting transform variable is essentially a recipe for the processing steps. We then define a transforms() function that applies this recipe to the new data, creating a "pixel_values" column containing the transformed images that will be fed into the model. Finally, the .with_transform() method attaches the transforms to the dataset.
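
A sketch of this pipeline, continuing from the checkpoint above (note that ToTensor() must come before Normalize() in the recipe, since normalization operates on tensors; depending on the checkpoint, you may also need a Resize() step):

```python
from transformers import AutoImageProcessor
from torchvision.transforms import Compose, Normalize, ToTensor

processor = AutoImageProcessor.from_pretrained(checkpoint)

# The recipe: convert each PIL image to a tensor, then normalize it with
# the statistics the model was pretrained with
transform = Compose([
    ToTensor(),
    Normalize(mean=processor.image_mean, std=processor.image_std),
])

def transforms(examples):
    # Create the "pixel_values" column the model expects
    examples["pixel_values"] = [
        transform(img.convert("RGB")) for img in examples["image"]
    ]
    # Drop the raw images so only model-ready columns remain
    del examples["image"]
    return examples

# Apply the transforms on the fly whenever examples are accessed
dataset = dataset.with_transform(transforms)
```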

7. Plotting transformed data

We can plot the pixel values with .imshow(). The .permute() method is required to move the color channel dimension from the beginning to the end of the tensor, because PyTorch stores images channels-first while the plotting library expects channels-last.
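
For instance, to display the first transformed training image (the normalized pixel values fall outside [0, 1], so matplotlib will clip them for display):

```python
import matplotlib.pyplot as plt

# pixel_values is a (channels, height, width) tensor
image = dataset["train"][0]["pixel_values"]

# imshow() expects (height, width, channels), so move channels to the end
plt.imshow(image.permute(1, 2, 0))
plt.show()
```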

8. Training

Now to configure the training process. The TrainingArguments class is designed to provide fully-functional PyTorch training, so there are many more arguments than are displayed here; we'll just focus on the key ones. The learning_rate is the size of each correction the model makes as it learns the new features; gradient_accumulation_steps specifies how many batches the model processes before updating its weights; and num_train_epochs specifies the maximum number of training cycles. The Trainer class then takes the model, training arguments, training and evaluation datasets, preprocessor, and a data collator, which batches the data.
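
A sketch of this configuration; the hyperparameter values are illustrative only, and older transformers versions pass the preprocessor via the tokenizer argument rather than processing_class:

```python
from transformers import TrainingArguments, Trainer, DefaultDataCollator

# Illustrative hyperparameters; tune these for your own data
training_args = TrainingArguments(
    output_dir="ai-image-detector",
    learning_rate=5e-5,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=processor,  # older transformers versions use tokenizer=processor
    data_collator=DefaultDataCollator(),
)
```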

9. Evaluation

We can run the .predict() method on the trainer prior to training, and the accuracy shows that the model is practically guessing at random. After training for three epochs with the .train() method, we see a big improvement!
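
One way to check this, using a small hypothetical accuracy helper built from the prediction output:

```python
import numpy as np

def accuracy(output):
    # Pick the highest-scoring class and compare against the true labels
    preds = np.argmax(output.predictions, axis=-1)
    return (preds == output.label_ids).mean()

# Before training: accuracy near chance (about 0.5 for two classes)
print(accuracy(trainer.predict(dataset["test"])))

trainer.train()

# After three epochs: a much higher accuracy
print(accuracy(trainer.predict(dataset["test"])))
```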

10. Let's practice!

Let's begin fine-tuning!