Een train DataLoader maken
Nu we onze gegevensset hebben gesplitst, moeten we een data loader definiëren die tijdens het trainen batches data aanlevert. DataLoader laadt data efficiënt in het geheugen en maakt schudden mogelijk voor betere generalisatie. In deze oefening maak je de methode train_dataloader af.
Deze oefening maakt deel uit van de cursus
Schaalbare AI-modellen met PyTorch Lightning
Oefeninstructies
- Importeer de
DataLoader. - Retourneer een
DataLoaderdieself.train_datalaadt en schudden inschakelt voor betere generalisatie.
Praktische interactieve oefening
Probeer deze oefening eens door deze voorbeeldcode in te vullen.
# Import libraries
from torch.utils.data import ____
import lightning.pytorch as pl
class LoaderDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.train_data = None
self.val_data = None
def setup(self, stage=None):
self.train_data, self.val_data = random_split(dataset, [80, 20])
def train_dataloader(self):
# Complete DataLoader
return ____(____, batch_size=16, shuffle=____)