ComeçarComece de graça

Dividindo dados com o LightningDataModule

Você vai terminar o método “ setup ” em um “ LightningDataModule ”. A divisão certa dos conjuntos de dados garante que o modelo seja treinado em um subconjunto e validado em outro, evitando o sobreajuste.

O dataset já foi pré-importado.

Este exercício faz parte do curso

Modelos de IA escaláveis com PyTorch Lightning

Ver curso

Instruções do exercício

  • Importa o arquivo “ random_split ” pra dividir o conjunto de dados em treinamento e validação.
  • Divida o conjunto de dados em treinamento (80%) e validação (20%) usando random_split.

Exercício interativo prático

Experimente este exercício completando este código de exemplo.

# Import libraries 
import lightning.pytorch as pl
from torch.utils.data import ____

class SplitDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_data = None
        self.val_data = None
    def setup(self, stage=None):
        # Split the dataset into training (80%) and validation (20%)
        self.____, self.____ = random_split(dataset, [____, ____])
Editar e executar o código