Modelo de clasificación multiclase
Con una plantilla para un modelo de clasificación binaria, ahora puedes basarte en ella para diseñar un modelo de clasificación multiclase. El modelo debe manejar diferentes números de clases mediante un parámetro, lo que te permite adaptar el modelo a una tarea de clasificación multiclase específica en el futuro.
Se han importado los paquetes torch
y torch.nn
como nn
. Todos los tamaños de imagen son de 64 x 64 píxeles.
Este ejercicio forma parte del curso
Aprendizaje profundo para imágenes con PyTorch
Instrucciones del ejercicio
- Define el método «
__init__
» incluyendo «self
» y «num_classes
» como parámetros. - Crea una capa totalmente conectada con un tamaño de entrada de
16*32*32
y el número de clasesnum_classes
como salida. - Crea una función de activación
softmax
condim=1
.
Ejercicio interactivo práctico
Prueba este ejercicio completando el código de muestra.
class MultiClassImageClassifier(nn.Module):
# Define the init method
def ____(____, ____):
super(MultiClassImageClassifier, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
# Create a fully connected layer
self.fc = ____(____, ____)
# Create an activation function
self.softmax = ____(____)