Gegevens splitsen met LightningDataModule
Je gaat de setup-methode in een LightningDataModule afmaken. Een juiste opsplitsing van de gegevensset zorgt ervoor dat het model getraind wordt op het ene deel en gevalideerd op een ander deel, zodat je overfitting voorkomt.
De dataset is al vooraf geïmporteerd.
Deze oefening maakt deel uit van de cursus
Schaalbare AI-modellen met PyTorch Lightning
Oefeninstructies
- Importeer
random_splitom de gegevensset op te splitsen in training en validatie. - Splits de gegevensset in training (80%) en validatie (20%) met behulp van
random_split.
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
# Import libraries
import lightning.pytorch as pl
from torch.utils.data import ____
class SplitDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.train_data = None
self.val_data = None
def setup(self, stage=None):
# Split the dataset into training (80%) and validation (20%)
self.____, self.____ = random_split(dataset, [____, ____])