Dividindo dados com o LightningDataModule
Você vai terminar o método “ setup
” em um “ LightningDataModule
”. A divisão certa dos conjuntos de dados garante que o modelo seja treinado em um subconjunto e validado em outro, evitando o sobreajuste.
O dataset
já foi pré-importado.
Este exercício faz parte do curso
Modelos de IA escaláveis com PyTorch Lightning
Instruções do exercício
- Importa o arquivo “
random_split
” pra dividir o conjunto de dados em treinamento e validação. - Divida o conjunto de dados em treinamento (80%) e validação (20%) usando
random_split
.
Exercício interativo prático
Experimente este exercício completando este código de exemplo.
# Import libraries
import lightning.pytorch as pl
from torch.utils.data import ____
class SplitDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.train_data = None
self.val_data = None
def setup(self, stage=None):
# Split the dataset into training (80%) and validation (20%)
self.____, self.____ = random_split(dataset, [____, ____])