Implementing the validation step
While training a neural network, we need to monitor how well it performs on data it is not trained on. Using PyTorch Lightning, implement the validation_step() method, which Lightning runs on each validation batch, to calculate and log the validation loss at each epoch.
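Monitoring validation performance comes down to a few steps that Lightning performs around validation_step(): switch the model to evaluation mode, disable gradient tracking, compute the loss batch by batch, and reduce the logged values to one number per epoch. The plain-PyTorch sketch below is a simplified illustration of that bookkeeping, not Lightning's actual implementation; the helper validate_one_epoch, the toy linear model, and the random data are hypothetical. It averages per-batch losses weighted by batch size, to mirror Lightning's default mean reduction, which in recent versions weights each batch by its size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def validate_one_epoch(model: nn.Module, loader: DataLoader) -> float:
    """Hypothetical helper: mean cross-entropy over a whole validation set."""
    was_training = model.training
    model.eval()  # e.g. disables dropout, uses running batch-norm statistics
    total_loss, total_count = 0.0, 0
    with torch.no_grad():  # validation needs no gradients
        for x, y in loader:
            loss = F.cross_entropy(model(x), y)  # mean loss over this batch
            total_loss += loss.item() * x.size(0)  # weight by batch size
            total_count += x.size(0)
    model.train(was_training)  # restore the previous mode
    return total_loss / total_count


torch.manual_seed(0)
model = nn.Linear(4, 3)  # toy model: 4 input features, 3 classes
x = torch.randn(10, 4)
y = torch.randint(0, 3, (10,))
loader = DataLoader(TensorDataset(x, y), batch_size=4)  # batches of 4, 4 and 2
val_loss = validate_one_epoch(model, loader)
print(f"val_loss = {val_loss:.4f}")
```

Because the last batch holds only 2 of the 10 samples, weighting by batch size makes the result equal the loss over the whole validation set instead of over-counting the small final batch.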
This exercise is part of the course Scalable AI Models with PyTorch Lightning.
Exercise instructions
- Compute predictions by applying the model to the input batch.
- Calculate the validation loss using F.cross_entropy().
- Log the validation loss with self.log() as val_loss.
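F.cross_entropy() takes raw, unnormalized scores (logits) of shape (batch, classes) and integer class indices of shape (batch,), and by default returns the mean loss over the batch. The small worked example below, with made-up logits and targets, checks that it matches the manual computation: log-softmax over the class dimension, pick each sample's target entry, negate, and average.

```python
import torch
import torch.nn.functional as F

# Made-up logits for a batch of 2 samples and 3 classes
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
# Integer class index for each sample
targets = torch.tensor([0, 2])

loss = F.cross_entropy(logits, targets)

# Same value by hand: log-softmax, select each sample's target column, negate, average
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()

print(loss.item(), manual.item())
```

Because F.cross_entropy() applies log-softmax internally, the model should return raw logits; applying softmax in the model first would effectively apply it twice.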
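The sample code with the blanks completed is below. It assumes the module's forward() applies the network, so calling self(x) returns the predictions.
import torch.nn.functional as F

def validation_step(self, batch, batch_idx):
    x, y = batch
    # Compute predictions (logits); calling the module runs its forward()
    preds = self(x)
    # Calculate validation loss
    loss = F.cross_entropy(preds, y)
    # Log the validation loss under the name 'val_loss'
    self.log('val_loss', loss)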