Implementing the training step
In this exercise, you'll implement the training_step() method in a PyTorch Lightning module designed for an image classification task.
Your implementation should unpack a batch of images and labels, compute the model predictions via the forward pass, calculate the cross entropy loss, and log the training loss.
This exercise is part of the course
Scalable AI Models with PyTorch Lightning
Exercise instructions
- Ensure that you compute predictions using the forward pass.
- Calculate the cross entropy loss.
- Log the training loss.
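To make the loss in the second instruction concrete, here is a minimal plain-Python sketch of what a cross entropy loss computes from raw model scores (logits): the negative log-probability of the correct class, averaged over the batch. The function name cross_entropy_from_logits and the sample numbers are hypothetical, chosen only for illustration. PyTorch's torch.nn.functional.cross_entropy likewise takes raw logits and integer class labels and averages over the batch by default, but it operates on tensors, so treat this as an explanation of the arithmetic rather than a substitute for the library call.

```python
import math

def cross_entropy_from_logits(logits, targets):
    """Mean cross entropy over a batch, computed from raw logits.

    logits:  list of per-sample score lists, one score per class
    targets: list of integer class indices, one per sample
    """
    losses = []
    for scores, target in zip(logits, targets):
        # Log-sum-exp, with the max subtracted for numerical stability
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        # Negative log-probability assigned to the correct class
        losses.append(log_sum_exp - scores[target])
    # Average over the batch
    return sum(losses) / len(losses)

# Two samples, three classes (made-up numbers for illustration)
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
targets = [0, 2]
loss = cross_entropy_from_logits(logits, targets)
print(f"mean cross entropy: {loss:.4f}")
```

The second sample has identical scores for all three classes, so its loss is log(3): the model is no better than a uniform guess. A confident, correct prediction such as the first sample contributes a much smaller value.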
Hands-on interactive exercise
Try this exercise by completing the sample code below.
from torch.nn.functional import cross_entropy
def training_step(self, batch, batch_idx):
    x, y = batch
    # Ensure that you compute predictions using the forward pass
    y_hat = ____
    # Calculate the cross entropy loss
    loss = ____
    # Log the loss
    self.____("train_loss", loss)
    return loss