
Evaluating model performance

1. Evaluating model performance

We've done a lot of training. Now let's evaluate our models.

2. Training, validation and testing

In machine learning, data is split into training, validation, and test sets. The training data adjusts the model's parameters, such as weights and biases; the validation data tunes hyperparameters like learning rate and momentum; and the test set evaluates the model's final performance after training is complete. During training and validation, we'll track two key metrics: loss and accuracy.

3. Calculating training loss

Let's begin with loss. Training loss is calculated by summing the loss across all batches in the training dataloader; at the end of each epoch, we divide this total by the number of batches to get the mean training loss. We start by setting training_loss to zero. We then iterate through the trainloader, run a forward pass, and compute the loss. As usual, we compute gradients with backpropagation and update the weights with the optimizer. We add each batch's loss to the running total using .item(), which extracts the numerical value from a tensor. Since one epoch is a full pass through the training dataloader, we compute the mean loss by dividing training_loss by the number of batches in the trainloader.
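A minimal sketch of this training loop, assuming a toy regression setup (the names model, criterion, optimizer, and trainloader are illustrative, not from the course's dataset):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy data and model so the loop is runnable end to end
torch.manual_seed(0)
features = torch.randn(64, 4)
targets_all = torch.randn(64, 1)
trainloader = DataLoader(TensorDataset(features, targets_all), batch_size=16)

model = nn.Linear(4, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

training_loss = 0.0
for inputs, targets in trainloader:
    optimizer.zero_grad()
    outputs = model(inputs)             # forward pass
    loss = criterion(outputs, targets)  # batch loss
    loss.backward()                     # compute gradients via backpropagation
    optimizer.step()                    # update weights
    training_loss += loss.item()        # .item() extracts the scalar from the tensor

# One epoch = one full pass through the dataloader,
# so the mean loss is the total divided by the number of batches
epoch_loss = training_loss / len(trainloader)
print(f"Mean training loss: {epoch_loss:.4f}")
```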

4. Calculating validation loss

After each training epoch, we run a validation loop. First, we set the model to evaluation mode using .eval(), as some layers behave differently during training and validation. To improve efficiency, we use torch.no_grad(), which disables gradient calculations since we don't update weights during validation. We then iterate through the validation dataloader, run a forward pass, and compute the loss, summing it across batches. At the end of the epoch, we calculate the mean validation loss. Finally, we switch the model back to training mode with .train(), preparing it for the next training epoch.
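The validation loop follows the same pattern, wrapped in eval mode and torch.no_grad(). A sketch with the same illustrative names (validationloader, model, criterion are assumptions):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
val_features = torch.randn(32, 4)
val_targets = torch.randn(32, 1)
validationloader = DataLoader(TensorDataset(val_features, val_targets), batch_size=16)

model = nn.Linear(4, 1)
criterion = nn.MSELoss()

validation_loss = 0.0
model.eval()                  # some layers (e.g. dropout) behave differently in eval mode
with torch.no_grad():         # disable gradient tracking: no weight updates here
    for inputs, targets in validationloader:
        outputs = model(inputs)             # forward pass only
        loss = criterion(outputs, targets)
        validation_loss += loss.item()      # sum loss across batches

epoch_val_loss = validation_loss / len(validationloader)  # mean validation loss
model.train()                 # switch back to training mode for the next epoch
print(f"Mean validation loss: {epoch_val_loss:.4f}")
```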

5. Overfitting

Keeping track of training and validation loss helps us detect overfitting. When a model overfits, training loss keeps decreasing while validation loss starts to rise. This means the model is memorizing the training data rather than learning patterns that generalize to new data.

6. Calculating accuracy with torchmetrics

Loss tells us how well a model is learning, but it doesn't always reflect how accurately it makes predictions. To measure that, we track accuracy using torchmetrics. For multi-class classification tasks, we create an accuracy metric with torchmetrics.Accuracy. As the model processes each batch, we update this metric using its predictions and the actual labels. Since the model outputs a probability for each class, we use argmax(dim=-1) to select the class with the highest probability, converting each probability vector into a single class index before passing it to the metric. At the end of each epoch, we calculate the overall accuracy using .compute(). Finally, we reset the metric with .reset() to clear its state before the next epoch. This process is the same for both training and validation.

7. Let's practice!

Time for some practice!