Multi-input models

1. Multi-input models

Let's learn to build multi-input models!

2. Why multi-input?

Multi-input models, or models that accept more than one source of data, have many applications. First, we might want the model to use multiple information sources, such as two images of the same car to predict its model. Second, multi-modal models can work on different input types such as image and text to answer a question about the image. Next, in metric learning, the model learns whether two inputs represent the same object. Think about an automated passport control where the system compares our passport photo with a picture it takes of us. Finally, in self-supervised learning, the model learns data representation by learning that two augmented versions of the same input represent the same object. Multi-input models are everywhere!

3. Omniglot dataset

Throughout the chapter, we will be using the Omniglot dataset, a collection of images of 964 different handwritten characters from 30 different alphabets.

4. Character classification

Let's use the Omniglot dataset to build a two-input model to classify handwritten characters. The first input will be the image of the character, such as this Latin letter "k".

5. Character classification

The second input will be the alphabet that it comes from, expressed as a one-hot vector.

6. Character classification

Both inputs will be processed separately, then we concatenate their representations.

7. Character classification

Finally a classification layer predicts one of the 964 classes. We need two elements to build such a model: a custom Dataset and an appropriate model architecture.

8. Two-input Dataset

Let's start with the custom Omniglot dataset. We set it up as a class based on torch Dataset. In the init method, we store transform and samples provided when instantiating the dataset as class attributes. Samples are tuples of three: image file path, alphabet as a one-hot vector, and target label as the character class index. In the exercises, samples will be provided. For personal projects, we would need to create them from data file paths. Next, we need to implement the len method that returns the number of samples. Finally, the getitem method returns one sample based on the index it receives as input. For the given index, we retrieve the sample and load the image using Image.open from PIL. The convert method with the argument "L" makes sure that the image is read as grayscale. Then, we transform the image and return a triplet: the transformed image, the alphabet vector, and the target label.
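The custom Dataset described above can be sketched as follows. This is a minimal sketch: the class name `OmniglotDataset` and the exact shape of each sample follow the description in the slide, but the course's actual code may differ in details.

```python
from PIL import Image
from torch.utils.data import Dataset

class OmniglotDataset(Dataset):
    def __init__(self, transform, samples):
        # samples: list of triplets (image_path, alphabet_one_hot, label)
        self.transform = transform
        self.samples = samples

    def __len__(self):
        # Number of samples in the dataset
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, alphabet, label = self.samples[idx]
        # "L" ensures the image is loaded as single-channel grayscale
        img = Image.open(img_path).convert("L")
        img = self.transform(img)
        # Return the triplet: transformed image, alphabet vector, target label
        return img, alphabet, label
```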

9. Tensor concatenation

Before we proceed to building the model, we need to understand tensor concatenation. torch.cat concatenates tensors along a specified dimension. We pass it the tensors and the dimension: for 2D tensors, dim=0 stacks them vertically (along rows), while dim=1 places them side by side horizontally (along columns).
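A quick example of both concatenation directions for 2D tensors:

```python
import torch

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])

# dim=0 stacks the tensors vertically: result has shape (4, 2)
print(torch.cat((a, b), dim=0))
# dim=1 places them side by side: result has shape (2, 4)
print(torch.cat((a, b), dim=1))
```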

10. Two-input architecture

It's time to define our two-input model. We start by defining a sub-network, or layer, to process our first input, the image. It should look familiar: a convolution, max pooling, an ELU activation, and finally flattening into a linear layer with output size 128. Next, we define a layer to process our second input, the alphabet vector. Its input size is 30, the number of alphabets, and we map it to an arbitrarily chosen output size of 8. Then, the classifier accepts input of size 128 plus 8 (the concatenated image and alphabet outputs) and produces output of size 964, the number of classes.

11. Two-input architecture

In the forward method, we pass each input through its corresponding layer. Then, we concatenate the outputs with torch.cat. Finally, we pass the result through the classifier layer and return.
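Putting the two slides together, the architecture can be sketched as one module. The layer names and the assumption of 64x64 grayscale input images (which determines the flattened size of 16 * 32 * 32) are illustrative choices, not taken from the course code; the 30 alphabets, the size-8 alphabet embedding, the 128+8 concatenation, and the 964 output classes follow the text.

```python
import torch
import torch.nn as nn

class TwoInputModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Sub-network for the image input; assumes 1x64x64 grayscale images
        self.image_layer = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ELU(),
            nn.Flatten(),
            nn.Linear(16 * 32 * 32, 128),
        )
        # Sub-network for the alphabet one-hot vector: 30 alphabets -> size 8
        self.alphabet_layer = nn.Sequential(
            nn.Linear(30, 8),
            nn.ELU(),
        )
        # Classifier over the concatenated representations: 128 + 8 -> 964 classes
        self.classifier = nn.Linear(128 + 8, 964)

    def forward(self, x_image, x_alphabet):
        # Process each input through its own sub-network
        x_image = self.image_layer(x_image)
        x_alphabet = self.alphabet_layer(x_alphabet)
        # Concatenate the two representations along the feature dimension
        x = torch.cat((x_image, x_alphabet), dim=1)
        return self.classifier(x)
```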

12. Training loop

The training loop looks just like all the ones we have seen so far. The only difference is that now the training data consists of three items: the image, the alphabet vector, and the labels, and we pass the images and alphabets to the model.
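The loop might look like the sketch below. The function name `train`, the optimizer, learning rate, and epoch count are illustrative assumptions; the key point from the text is that each batch unpacks into three items and that both the images and the alphabet vectors are passed to the model.

```python
import torch
import torch.nn as nn

def train(model, dataloader_train, num_epochs=10, lr=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(num_epochs):
        # Each batch now unpacks into three items
        for img, alphabet, labels in dataloader_train:
            optimizer.zero_grad()
            # Pass both inputs to the model
            outputs = model(img, alphabet)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
```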

13. Let's practice!

Let's build a multi-input model!