
PyTorch Dataset

Time to refresh your PyTorch Datasets knowledge!

Before model training can commence, you need to load the data and pass it to the model in the right format. In PyTorch, this is handled by Datasets and DataLoaders. Let's start by building a PyTorch Dataset for our water potability data.
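As a rough illustration of that division of labor, here is a sketch that is not part of the exercise. It wraps hypothetical random tensors in PyTorch's built-in TensorDataset, which serves individual samples, and lets a DataLoader group them into mini-batches. The number of samples, the three features, and the batch size are arbitrary assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples, 3 features each, and a binary label
features = torch.randn(10, 3)
labels = torch.randint(0, 2, (10,)).float()

# TensorDataset is a ready-made Dataset that indexes both tensors together
dataset = TensorDataset(features, labels)

# The DataLoader groups samples into mini-batches and shuffles them each epoch
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

batch_shapes = []
for batch_features, batch_labels in dataloader:
    # Each iteration yields one batch: features (batch, 3) and labels (batch,)
    batch_shapes.append((tuple(batch_features.shape), tuple(batch_labels.shape)))

print(batch_shapes)
```

Because 10 is not a multiple of 4 and `drop_last` defaults to `False`, the last batch is smaller and holds the remaining 2 samples.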

In this exercise, you will define a class called WaterDataset to load the data from a CSV file. To do this, you will need to implement the three methods that PyTorch expects a Dataset to have:

  • .__init__() to load the data,
  • .__len__() to return the number of samples,
  • .__getitem__() to extract the features and label of a single sample.

The following imports have already been done for you:

import pandas as pd
from torch.utils.data import Dataset

This exercise is part of the course Intermediate Deep Learning with PyTorch.

Have a go at this exercise by completing this sample code.

class WaterDataset(Dataset):
    def __init__(self, csv_path):
        super().__init__()
        # Load data to pandas DataFrame
        df = ____
        # Convert data to a NumPy array and assign to self.data
        ____ = ____.____
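For reference, a complete WaterDataset might look like the following minimal sketch. It assumes that the CSV contains only numeric columns and that the label (potability) is stored in the last column. The source does not state the column layout. The tiny CSV written in the usage part, including its column names and values, is hypothetical and exists only to make the example runnable.

```python
import os
import tempfile

import pandas as pd
from torch.utils.data import Dataset


class WaterDataset(Dataset):
    """Map-style Dataset serving rows of a CSV file.

    Assumption: all columns are numeric and the label is the last column.
    """

    def __init__(self, csv_path):
        super().__init__()
        # Load data to pandas DataFrame
        df = pd.read_csv(csv_path)
        # Convert data to a NumPy array and assign to self.data
        self.data = df.to_numpy()

    def __len__(self):
        # One sample per row
        return self.data.shape[0]

    def __getitem__(self, idx):
        # Every column except the last holds features (assumed layout)
        features = self.data[idx, :-1]
        # The last column holds the label (assumed layout)
        label = self.data[idx, -1]
        return features, label


# Usage with a tiny hypothetical CSV: two feature columns plus a label column
with tempfile.TemporaryDirectory() as tmp_dir:
    csv_path = os.path.join(tmp_dir, "water_sample.csv")
    pd.DataFrame(
        {"ph": [7.0, 6.5, 8.1], "Hardness": [200.0, 180.0, 220.0], "Potability": [1, 0, 1]}
    ).to_csv(csv_path, index=False)
    dataset = WaterDataset(csv_path)
    features, label = dataset[1]
    print(len(dataset), features, label)
```

Holding the whole array in memory keeps `__getitem__` trivial and is fine for a small CSV. Data that does not fit in memory would instead need to be read lazily inside `__getitem__`.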