Membuat DataLoader pelatihan
Sekarang setelah himpunan data kita dibagi, kita perlu mendefinisikan data loader untuk menyediakan batch data selama pelatihan. DataLoader memuat data ke memori secara efisien dan memungkinkan pengacakan untuk generalisasi yang lebih baik. Dalam latihan ini, Anda akan melengkapi metode train_dataloader.
Latihan ini merupakan bagian dari kursus
Model AI yang Dapat Diskalakan dengan PyTorch Lightning
Instruksi latihan
- Impor
DataLoader. - Kembalikan
DataLoaderyang memuatself.train_data, dengan mengaktifkan pengacakan untuk meningkatkan generalisasi.
Latihan interaktif langsung praktik
Cobalah latihan ini dengan melengkapi kode contoh ini.
# Import libraries
from torch.utils.data import ____
import lightning.pytorch as pl
class LoaderDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.train_data = None
self.val_data = None
def setup(self, stage=None):
self.train_data, self.val_data = random_split(dataset, [80, 20])
def train_dataloader(self):
# Complete DataLoader
return ____(____, batch_size=16, shuffle=____)