Writing a training loop
In scikit-learn, the training loop is wrapped in the .fit() method, while in PyTorch, you set it up manually. This adds flexibility but requires a custom implementation.
In this exercise, you'll create a loop to train a model for salary prediction.
The show_results() function is provided to help you visualize some sample predictions.
The package imports provided are: pandas as pd, torch, torch.nn as nn, torch.optim as optim, as well as DataLoader and TensorDataset from torch.utils.data.
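In code, the imports listed above look like the first lines of the sketch below. The rest shows one plausible way a dataloader like the provided one could be built from a pandas DataFrame. The exercise does not show how its dataloader was created, so the column names (experience_level, remote_ratio, salary), the sample values, and the batch size are all hypothetical.

```python
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical salary data: the real exercise dataset and its columns are not shown
df = pd.DataFrame({
    "experience_level": [0.0, 1.0, 2.0, 3.0, 1.0, 2.0],
    "remote_ratio": [0.0, 0.5, 1.0, 0.0, 1.0, 0.5],
    "salary": [0.30, 0.45, 0.60, 0.80, 0.50, 0.65],
})

# Features with shape (N, 2) and target with shape (N, 1), both float32
features = torch.tensor(
    df[["experience_level", "remote_ratio"]].to_numpy(), dtype=torch.float32
)
target = torch.tensor(df[["salary"]].to_numpy(), dtype=torch.float32)

# Pair each feature row with its target, then serve shuffled mini-batches
dataset = TensorDataset(features, target)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

for batch_features, batch_target in dataloader:
    print(batch_features.shape, batch_target.shape)
```

Keeping the target two-dimensional, shape (N, 1), matches the (batch, 1) output of a regression model with a single output unit, so nn.MSELoss() compares tensors of the same shape.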
The following variables have been created: num_epochs, containing the number of epochs (set to 5); dataloader, containing the dataloader; model, containing the neural network; criterion, containing the loss function, nn.MSELoss(); optimizer, containing the SGD optimizer.
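A minimal sketch of such a training loop follows. Only the loop itself reflects the task described here. The synthetic data, the model architecture, the batch size, and the learning rate are stand-ins for the provided dataloader, model, criterion, and optimizer, which the exercise does not show.

```python
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Stand-ins for the provided variables (hypothetical data, sizes, and learning rate)
features = torch.rand(64, 4)
target = features.sum(dim=1, keepdim=True)  # synthetic regression target, shape (64, 1)
dataloader = DataLoader(TensorDataset(features, target), batch_size=8, shuffle=True)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
num_epochs = 5

epoch_losses = []
for epoch in range(num_epochs):
    running_loss = 0.0
    for batch_features, batch_target in dataloader:
        # Clear gradients left over from the previous batch
        optimizer.zero_grad()
        # Forward pass: predictions for this batch
        predictions = model(batch_features)
        # Mean squared error between predictions and targets
        loss = criterion(predictions, batch_target)
        # Backward pass: gradients of the loss w.r.t. every parameter
        loss.backward()
        # SGD update using those gradients
        optimizer.step()
        running_loss += loss.item()
    # Average batch loss for this epoch
    epoch_loss = running_loss / len(dataloader)
    epoch_losses.append(epoch_loss)
    print(f"Epoch {epoch + 1}/{num_epochs}, average loss: {epoch_loss:.4f}")
```

The order inside the inner loop is what matters: gradients must be cleared before loss.backward(), because PyTorch accumulates gradients across calls, and optimizer.step() must come after backward() so that it reads the current batch's gradients.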
This is a part of the course “Introduction to Deep Learning with PyTorch”.
Learn how to build your first neural network, adjust hyperparameters, and tackle classification and regression problems in PyTorch.
To train a neural network in PyTorch, you first need to understand the role of a loss function. Training a network means minimizing that loss function, which is done by calculating gradients. You will learn how to use these gradients to update your model's parameters and, finally, write your first training loop.
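The sketch below walks through the chain from loss to gradients to a parameter update on a hypothetical single linear layer with made-up numbers. It checks that nn.MSELoss() (default reduction "mean") equals the mean of squared differences, calls backward() to fill each parameter's .grad, and then applies one manual gradient-descent step. That manual step corresponds to what plain SGD without momentum does inside optimizer.step().

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny regression problem: 3 samples, 2 features
inputs = torch.tensor([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]])
targets = torch.tensor([[1.0], [0.0], [2.0]])
model = nn.Linear(2, 1)

criterion = nn.MSELoss()
predictions = model(inputs)
loss = criterion(predictions, targets)

# MSELoss with the default reduction is the mean of squared differences
manual_loss = ((predictions - targets) ** 2).mean()
print(torch.allclose(loss, manual_loss))

# Backward pass populates .grad on the weight and bias
loss.backward()
print(model.weight.grad, model.bias.grad)

# Manual gradient-descent step: move each parameter against its gradient.
# no_grad() stops autograd from recording the in-place update itself.
learning_rate = 0.01
with torch.no_grad():
    for param in model.parameters():
        param -= learning_rate * param.grad

# Recompute the loss with the updated parameters
new_loss = criterion(model(inputs), targets)
print(loss.item(), new_loss.item())
```

With a small enough learning rate, a single step against the gradient lowers the loss on the same data. Repeating this step over many batches and epochs is exactly what a training loop does.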