1. Learn
  2. /
  3. 课程
  4. /
  5. PyTorch 深度学习进阶

Connected

道练习

PyTorch 数据集

来复习一下您对 PyTorch Datasets 的理解!

在开始训练模型之前,您需要先加载数据,并以正确的格式传给模型。在 PyTorch 中,这由 Dataset 和 DataLoader 负责。我们先为饮用水可饮用性数据构建一个 PyTorch 数据集。

在本练习中,您将定义一个名为 WaterDataset 的类,用于从 CSV 文件加载数据。为此,您需要实现 PyTorch 期望 Dataset 具备的 3 个方法:

  • .__init__() 用于加载数据,
  • .__len__() 返回数据大小,
  • .__getitem()__ 为单个样本提取特征和标签。

以下导入已为您完成:

import pandas as pd
from torch.utils.data import Dataset

说明 1 / 共 3 个

undefined XP
    1
    2
    3
  • 在 .__init__() 方法中,从 csv_path 加载数据到一个 pandas DataFrame,并将其赋给 df。
  • 将 df 转换为 NumPy 数组,并将结果赋给 self.data。