Aan de slagBegin gratis

Gegevens splitsen met LightningDataModule

Je gaat de setup-methode in een LightningDataModule afmaken. Een juiste opsplitsing van de gegevensset zorgt ervoor dat het model getraind wordt op het ene deel en gevalideerd op een ander deel, zodat je overfitting voorkomt.

De dataset is al vooraf geïmporteerd.

Deze oefening maakt deel uit van de cursus

Schaalbare AI-modellen met PyTorch Lightning

Bekijk cursus

Oefeninstructies

  • Importeer random_split om de gegevensset op te splitsen in training en validatie.
  • Splits de gegevensset in training (80%) en validatie (20%) met behulp van random_split.

Interactieve oefening met praktijkervaring

Probeer deze oefening door deze voorbeeldcode aan te vullen.

# Import libraries 
import lightning.pytorch as pl
from torch.utils.data import ____

class SplitDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_data = None
        self.val_data = None
    def setup(self, stage=None):
        # Split the dataset into training (80%) and validation (20%)
        self.____, self.____ = random_split(dataset, [____, ____])
Code bewerken en uitvoeren