Creating a train DataLoader
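Now that we have split our dataset, we need a data loader to feed batches of data to the model during training. PyTorch's DataLoader wraps a dataset and groups its samples into batches of a chosen size. With shuffle=True, it also reshuffles the samples at the start of every epoch, so the model never sees them in a fixed order, which helps it generalize. In this exercise, you'll complete the train_dataloader method.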
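As a minimal sketch of the batching and shuffling described above, the snippet below uses a hypothetical toy dataset of 80 samples. That count matches the 80-sample training split used later in this exercise. Each sample's value is simply its index, so you can see that every epoch visits all samples exactly once, in a new random order.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 80 samples whose only "feature" is the sample index,
# which makes the effect of shuffling easy to inspect.
toy_train = TensorDataset(torch.arange(80))

# batch_size=16 groups samples into batches; shuffle=True draws a fresh random
# order of the samples at the start of every epoch (every full pass over the loader).
loader = DataLoader(toy_train, batch_size=16, shuffle=True)

print(len(loader))  # 5 batches per epoch (80 / 16)

for epoch in range(2):
    seen = []
    for (batch,) in loader:
        seen.extend(batch.tolist())
    # The first few indices differ between epochs; the last value printed is True
    # because each epoch still covers all 80 samples exactly once.
    print(epoch, seen[:8], sorted(seen) == list(range(80)))
```

Because 80 divides evenly by 16, every batch here holds exactly 16 samples. With other sizes, the final batch would be smaller unless drop_last=True is passed.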
This exercise is part of the course
Scalable AI Models with PyTorch Lightning
Exercise instructions
- Import the DataLoader.
- Return a DataLoader that loads self.train_data, enabling shuffling for better generalization.
Hands-on interactive exercise
Try this exercise by completing the sample code below.
# Import libraries
from torch.utils.data import ____, random_split
import lightning.pytorch as pl

class LoaderDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_data = None
        self.val_data = None

    def setup(self, stage=None):
        # Assumes `dataset` is already defined and holds 100 samples, so [80, 20] is a valid split
        self.train_data, self.val_data = random_split(dataset, [80, 20])

    def train_dataloader(self):
        # Complete DataLoader
        return ____(____, batch_size=16, shuffle=____)