Using the TensorDataset class
In practice, loading your data into a PyTorch dataset is usually one of the first steps in creating and training a neural network with PyTorch.
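The TensorDataset class is very helpful when your dataset can be loaded directly as a NumPy array. Recall that TensorDataset() takes one or more tensors as input, not raw NumPy arrays. Convert NumPy arrays to tensors first, for example with torch.tensor(). All the tensors must have the same size along the first dimension (the number of samples).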
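The following is a minimal sketch of the NumPy-to-TensorDataset workflow described above. It assumes randomly generated toy data. The names features, target, x0 and y0 are illustrative and are not part of the exercise. The sketch shows three things: the conversion step, the shared first dimension, and the fact that each dataset item is a tuple with one slice per stored tensor.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset

# Illustrative toy data: 5 samples with 3 features each, plus one target value per sample
np_features = np.random.rand(5, 3)
np_target = np.random.rand(5, 1)

# TensorDataset stores tensors, so convert the NumPy arrays first.
# torch.tensor() copies the data and keeps NumPy's float64 dtype.
features = torch.tensor(np_features)
target = torch.tensor(np_target)

# Both tensors share the same size along dimension 0 (5 samples)
dataset = TensorDataset(features, target)

print(len(dataset))  # number of samples: 5
x0, y0 = dataset[0]  # each item is a (features_row, target_row) tuple
print(x0.shape, y0.shape)  # torch.Size([3]) torch.Size([1])
```

Because torch.tensor() keeps the float64 dtype, you may want to call .float() on the tensors if your model's weights use the default float32.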
In this exercise, you'll practice creating a PyTorch dataset using the TensorDataset class.
torch and numpy have already been imported for you, along with the TensorDataset class.
This exercise is part of the course “Introduction to Deep Learning with PyTorch”.
Exercise instructions
- Create a TensorDataset using the torch_features and torch_target tensors provided (in this order).
- Return the last element of the dataset.
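The next sketch shows what "the last element" of a TensorDataset looks like. It uses random stand-in data with the same shapes as the exercise (12 samples, 8 features, 1 target). The variable names are illustrative. Indexing a TensorDataset indexes every stored tensor at the same position, so a negative index such as -1 works just as it does on a tensor.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset

# Illustrative stand-in data with the exercise's shapes: 12 samples, 8 features, 1 target each
features = torch.tensor(np.random.rand(12, 8))
target = torch.tensor(np.random.rand(12, 1))
dataset = TensorDataset(features, target)

# Index -1 selects the last sample from every stored tensor and returns them as a tuple
last_features, last_target = dataset[-1]

# Same result as indexing with the explicit last position
assert torch.equal(last_features, dataset[len(dataset) - 1][0])
print(last_features.shape, last_target.shape)  # torch.Size([8]) torch.Size([1])
```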
Hands-on interactive exercise
Have a go at this exercise by completing this sample code.
import numpy as np
import torch
from torch.utils.data import TensorDataset
np_features = np.array(np.random.rand(12, 8))
np_target = np.array(np.random.rand(12, 1))
torch_features = torch.tensor(np_features)
torch_target = torch.tensor(np_target)
# Create a TensorDataset from two tensors
dataset = ____
# Return the last element of this dataset
print(____)
Learn how to build your first neural network, adjust hyperparameters, and tackle classification and regression problems in PyTorch.
Training a deep learning model is an art, and to make sure our model is trained correctly, we need to keep track of certain metrics during training, such as the loss or the accuracy. We will learn how to calculate such metrics and how to reduce overfitting using an image dataset as an example.
- Exercise 1: A deeper dive into loading data
- Exercise 2: Using the TensorDataset class
- Exercise 3: From data loading to running a forward pass
- Exercise 4: Evaluating model performance
- Exercise 5: Writing the evaluation loop
- Exercise 6: Calculating accuracy using torchmetrics
- Exercise 7: Fighting overfitting
- Exercise 8: Experimenting with dropout
- Exercise 9: Understanding overfitting
- Exercise 10: Improving model performance
- Exercise 11: Implementing random search
- Exercise 12: Wrap-up video