1. Learn
  2. /
  3. 课程
  4. /
  5. 使用 PyTorch 进行图像深度学习

Connected

道练习

分类器模块

接下来,您的任务是创建一个分类器模块,用来替换原始的 VGG16 分类器。您决定使用包含两个全连接层、且中间带有 ReLU 激活函数的模块。

您在上一个练习中定义的 vgg_model 和 input_dim 已在工作区可用,且已导入 torch 和 torchvision.models。

说明

100 XP
  • 创建一个变量 num_classes,表示类别数量,假设只需要检测猫和狗。
  • 使用 nn.Sequential 创建一个顺序模块。
  • 创建一个线性层,并将 in_features 设为 input_dim。
  • 在分类器的最后一层中设置输出特征数。