De trainingstap implementeren
In deze oefening implementeer je de methode training_step() in een PyTorch Lightning-module voor een beeldclassificatietaak.
Je implementatie moet een batch met afbeeldingen en labels uitpakken, de modelvoorspellingen berekenen via de forward pass, de cross-entropy loss berekenen en de trainingsloss loggen.
Deze oefening maakt deel uit van de cursus
Schaalbare AI-modellen met PyTorch Lightning
Oefeninstructies
- Zorg dat je de voorspellingen berekent via de forward pass.
- Bereken de cross-entropy loss.
- Log de trainingsloss.
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
from torch.nn.functional import cross_entropy
def training_step(self, batch, batch_idx):
x, y = batch
# Ensure that you compute predictions using the forward pass
y_hat = ____
# Calculate the cross entropy loss
loss = ____
# Log the loss
self.____("train_loss", loss)
return loss