1. Learn
  2. /
  3. 课程
  4. /
  5. 使用 PyTorch 进行图像深度学习

Connected

道练习

为 RPN 和 R-CNN 定义损失函数

您计划训练一个同时包含 RPN 和 R-CNN 组件的目标检测模型。要开始训练,您需要分别为每个组件定义其对应的损失函数。

请回忆:RPN 组件要判断一个候选区域内是否存在目标,并预测该候选区域的边界框坐标。R-CNN 组件需要将目标分类到多个类别之一,同时预测最终的边界框坐标。

已导入 torch、torch.nn 并命名为 nn。

说明

100 XP
  • 定义 RPN 的分类损失函数,并将其赋给 rpn_cls_criterion。
  • 定义 RPN 的回归损失函数,并将其赋给 rpn_reg_criterion。
  • 定义 R-CNN 的分类损失函数,并将其赋给 rcnn_cls_criterion。
  • 定义 R-CNN 的回归损失函数,并将其赋给 rcnn_reg_criterion。