Creación de un DataLoader de tren
Ahora que hemos dividido nuestro conjunto de datos, necesitamos definir un cargador de datos para proporcionar lotes de datos durante el entrenamiento. DataLoader
carga datos de forma eficiente en la memoria y permite mezclarlos para obtener una mejor generalización. En este ejercicio, completarás el método « train_dataloader
».
Este ejercicio forma parte del curso
Modelos de IA escalables con PyTorch Lightning
Instrucciones del ejercicio
- Importa el archivo
DataLoader
. - Devuelve un objeto de tipo «
DataLoader
» que carga «self.train_data
» y habilita la aleatorización para una mejor generalización.
Ejercicio interactivo práctico
Prueba este ejercicio completando el código de muestra.
# Import libraries
from torch.utils.data import ____
import lightning.pytorch as pl
class LoaderDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.train_data = None
self.val_data = None
def setup(self, stage=None):
self.train_data, self.val_data = random_split(dataset, [80, 20])
def train_dataloader(self):
# Complete DataLoader
return ____(____, batch_size=16, shuffle=____)