Creare un DataLoader di training
Ora che abbiamo suddiviso il nostro insieme di dati, dobbiamo definire un data loader per fornire i batch durante l'addestramento. DataLoader carica i dati in memoria in modo efficiente e consente lo shuffling per una migliore generalizzazione. In questo esercizio completerai il metodo train_dataloader.
Questo esercizio fa parte del corso
Modelli di AI scalabili con PyTorch Lightning
Istruzioni dell'esercizio
- Importa
DataLoader. - Ritorna un
DataLoaderche carichiself.train_data, abilitando lo shuffling per una migliore generalizzazione.
Esercizio pratico interattivo
Prova a risolvere questo esercizio completando il codice di esempio.
# Import libraries
from torch.utils.data import ____
import lightning.pytorch as pl
class LoaderDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.train_data = None
self.val_data = None
def setup(self, stage=None):
self.train_data, self.val_data = random_split(dataset, [80, 20])
def train_dataloader(self):
# Complete DataLoader
return ____(____, batch_size=16, shuffle=____)