1. Learn
  2. /
  3. 课程
  4. /
  5. PyTorch 深度学习进阶

Connected

道练习

训练多输出模型

在训练具有多个输出的模型时,必须确保正确地定义损失函数。

在这个例子中,模型会产生两个输出:一个是字母表的预测,另一个是字符的预测。每个输出都有对应的真实标签,这使您可以分别计算两个损失:一个来自字母表分类错误,另一个来自字符分类错误。由于这两种情况本质上都是多类别分类任务,因此每次都可以使用交叉熵损失(Cross-Entropy loss)。

然而,梯度下降一次只能优化一个损失函数。因此,您将把总体损失定义为字母表损失与字符损失之和。

说明

100 XP
  • 计算字母表分类损失,并将其赋给 loss_alpha。
  • 计算字符分类损失,并将其赋给 loss_char。
  • 计算总损失为两个部分损失之和,并将其赋给 loss。