
Using the TensorDataset class

In practice, loading your data into a PyTorch dataset is usually one of the first steps in creating and training a neural network with PyTorch.

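The TensorDataset class is useful when your dataset can be loaded directly as a NumPy array. Recall that TensorDataset() takes one or more tensors as input, so NumPy arrays must first be converted to tensors (for example with torch.tensor()), and all of them must have the same size along the first dimension.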
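
As a minimal sketch of the point above, the snippet below converts two small NumPy arrays into tensors and wraps them in a TensorDataset. The names X_np and y_np and their shapes (5 samples, 3 features, 1 target) are illustrative assumptions. Each item of the resulting dataset is a tuple holding the matching row of every tensor.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset

# Illustrative NumPy data (assumed shapes): 5 samples, 3 features, 1 target each
X_np = np.arange(15, dtype=np.float32).reshape(5, 3)
y_np = np.arange(5, dtype=np.float32).reshape(5, 1)

# TensorDataset expects tensors, so convert the arrays first.
# torch.from_numpy shares memory with the array; torch.tensor would make a copy.
X = torch.from_numpy(X_np)
y = torch.from_numpy(y_np)

dataset = TensorDataset(X, y)

print(len(dataset))            # 5: one item per entry along the first dimension
features, target = dataset[0]  # each item is a tuple, one entry per tensor
print(features)                # tensor([0., 1., 2.])
print(target)                  # tensor([0.])
```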

In this exercise, you'll practice creating a PyTorch dataset using the TensorDataset class.

torch and numpy have already been imported for you, along with the TensorDataset class.

This exercise is part of the course

“Introduction to Deep Learning with PyTorch”


Exercise instructions

  • Create a TensorDataset using the torch_features and the torch_target tensors provided (in this order).
  • Return the last element of the dataset.

Hands-on interactive exercise

Have a go at this exercise by completing this sample code.

import numpy as np
import torch
from torch.utils.data import TensorDataset

np_features = np.random.rand(12, 8)
np_target = np.random.rand(12, 1)

torch_features = torch.tensor(np_features)
torch_target = torch.tensor(np_target)

# Create a TensorDataset from two tensors
dataset = ____

# Return the last element of this dataset
print(____)
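
For reference, one way to complete the sample code is sketched below, assuming the same random 12 × 8 features and 12 × 1 target. The names last_features and last_target are illustrative additions, not part of the exercise.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset

np_features = np.random.rand(12, 8)
np_target = np.random.rand(12, 1)

torch_features = torch.tensor(np_features)
torch_target = torch.tensor(np_target)

# Create a TensorDataset from two tensors (features first, then target)
dataset = TensorDataset(torch_features, torch_target)

# Indexing a TensorDataset indexes every stored tensor the same way,
# so dataset[-1] returns the (features, target) pair for the last row
last_features, last_target = dataset[-1]
print(dataset[-1])
print(last_features.shape, last_target.shape)  # torch.Size([8]) torch.Size([1])
```

Because np.random.rand returns float64 arrays, torch.tensor() produces float64 tensors here. If they are later fed to a layer such as nn.Linear, whose parameters default to float32, converting with torch.tensor(np_features, dtype=torch.float32) (or calling .float()) avoids a dtype mismatch.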
