1. Learn
  2. /
  3. 课程
  4. /
  5. PyTorch 深度学习进阶

Connected

道练习

LSTM 网络

您已经知道,普通的 RNN 单元在实际中用得并不多。更常用、且能更好处理长序列的替代方案是长短期记忆(Long Short-Term Memory,LSTM)单元。本练习中,您将亲手构建一个 LSTM 网络!

与之前构建的 RNN 网络相比,最重要的实现差异在于:LSTM 具有两个隐藏状态,而不是一个。这意味着您需要额外初始化这个隐藏状态,并将其传入 LSTM 单元。

torch 和 torch.nn 已为您导入,请开始编码!

说明

100 XP
  • 在 .__init__() 方法中,定义一个 LSTM 层并将其赋给 self.lstm。
  • 在 forward() 方法中,用全零初始化第一个长期记忆隐藏状态 c0。
  • 在 forward() 方法中,将三个输入都传给 LSTM 层:当前时间步的输入,以及包含两个隐藏状态的元组。