1. Learn
  2. /
  3. 课程
  4. /
  5. PyTorch 深度学习进阶

Connected

道练习

序列数据集

很好,您已经构建好了 create_sequences() 函数!现在用它为模型创建训练数据集。

与表格数据和图像数据类似,顺序数据也最好通过 torch 的 Dataset 和 DataLoader 传递给模型。要构建一个序列 Dataset,您将调用 create_sequences() 获取包含输入和目标的 NumPy 数组,并查看它们的形状。接着,将它们传入 TensorDataset,创建一个正式的 torch Dataset,并查看其长度。

您实现的 create_sequences() 以及包含训练数据的 DataFrame train_data 已提供。

说明

100 XP
  • 调用 create_sequences(),传入训练用的 DataFrame 和序列长度 24*4,并将结果赋给 X_train, y_train。
  • 调用 TensorDataset 定义 dataset_train,传入两个参数:由 create_sequences() 生成的输入和目标,二者都需从 NumPy 数组转换为浮点类型的张量。