Codificación de cómo los cambios de peso afectan a la precisión
¡Ahora podrás cambiar los pesos en una red real y ver cómo afectan a la precisión del modelo!
Echa un vistazo a la siguiente red neuronal:
Sus pesos se han precargado como weights_0
. Tu tarea en este ejercicio es actualizar un único peso en weights_0
para crear weights_1
, que ofrece una predicción perfecta (en la que el valor predicho es igual a target_actual
: 3).
Si es necesario, utiliza lápiz y papel para probar diferentes combinaciones. Utilizarás la función « predict_with_network()
», que toma un arreglo de datos como primer argumento y los pesos como segundo argumento.
Este ejercicio forma parte del curso
Introducción al aprendizaje profundo en Python
Instrucciones del ejercicio
- Crea un diccionario de pesos llamado «
weights_1
» en el que hayas cambiado 1 peso de «weights_0
» (solo necesitas hacer 1 edición en «weights_0
» para generar la predicción perfecta). - Obtener predicciones con los nuevos pesos utilizando la función «
predict_with_network()
» con «input_data
» y «weights_1
». - Calcula el error de los nuevos pesos restando
target_actual
demodel_output_1
. - Pulsa «Enviar respuesta» para ver cómo se comparan los errores.
Ejercicio interactivo práctico
Prueba este ejercicio completando el código de muestra.
# The data point you will make a prediction for
input_data = np.array([0, 3])
# Sample weights
weights_0 = {'node_0': [2, 1],
'node_1': [1, 2],
'output': [1, 1]
}
# The actual target value, used to calculate the error
target_actual = 3
# Make prediction using original weights
model_output_0 = predict_with_network(input_data, weights_0)
# Calculate error: error_0
error_0 = model_output_0 - target_actual
# Create weights that cause the network to make perfect prediction (3): weights_1
weights_1 = {'node_0': [____, ____],
'node_1': [____, ____],
'output': [____, ____]
}
# Make prediction using new weights: model_output_1
model_output_1 = ____
# Calculate error: error_1
error_1 = ____ - ____
# Print error_0 and error_1
print(error_0)
print(error_1)