Creating a train DataLoader
Now that we have split our dataset, we need to define a data loader to provide batches of data during training. A DataLoader serves mini-batches efficiently and can shuffle the samples each epoch, which helps the model generalize. In this exercise, you'll complete the train_dataloader method.
This exercise is part of the course Scalable AI Models with PyTorch Lightning.
Exercise instructions
- Import DataLoader from torch.utils.data.
- Return a DataLoader that loads self.train_data, enabling shuffling for better generalization.
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
# Import libraries
from torch.utils.data import ____
import lightning.pytorch as pl

class LoaderDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_data = None
        self.val_data = None

    def setup(self, stage=None):
        # random_split and dataset are provided by the exercise environment
        self.train_data, self.val_data = random_split(dataset, [80, 20])

    def train_dataloader(self):
        # Complete the DataLoader
        return ____(____, batch_size=16, shuffle=____)
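For reference, one possible completion is sketched below. It is self-contained for illustration only: the dataset object and the random_split import are normally provided by the exercise environment, so the small TensorDataset used here is an assumption, not part of the original exercise.

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
import lightning.pytorch as pl

# Hypothetical placeholder dataset of 100 samples so the example runs on its own
dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))

class LoaderDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_data = None
        self.val_data = None

    def setup(self, stage=None):
        # Split the 100-sample dataset into 80 training and 20 validation samples
        self.train_data, self.val_data = random_split(dataset, [80, 20])

    def train_dataloader(self):
        # Serve shuffled mini-batches of 16 training samples
        return DataLoader(self.train_data, batch_size=16, shuffle=True)

With this split, calling setup() and then iterating over train_dataloader() yields five shuffled batches of 16 samples per epoch.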