Recently, I came across a blog post by Adam Kosiorek entitled Attention in Neural Networks and How to Use It. This very readable post discusses visual attention models and how he used them in one of his papers. I decided to give it a try myself, and here is the result: the most simplistic approach that still captures all the details needed to showcase this interesting field of research. As an example application, we will infer the position and size of digits in noisy images. The caveat is that we will not use this rich information to train the algorithm. Instead, we will only provide class labels (which digit is somewhere in the image).

The problem and how to approach it
I found the following definition online: weakly supervised learning is a machine learning framework where the model is trained using examples that are only partially annotated or labeled.
Weakly supervised learning, in this context, refers to the problem of inferring not only the digit label but also the digit's position and size via a bounding box. This is much more information than we feed into the model during training, which uses only the image and the corresponding class label.
The core idea is to use, alongside the main neural network that classifies the input image, a (Gaussian) soft attention network that selects/weights certain parts of the image, which we hope contain the actual digit. Both networks are coupled so that the attention network can be trained on just the class labels. To make sure that the bounding box is tight, we will regularize the variance parameters of the soft attention network to be small.
The dataset
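To make the Gaussian soft attention mechanism concrete, here is a minimal NumPy sketch of the standard construction: two banks of 1-D Gaussian filters (one per image axis) whose matrix products with the image produce a soft crop, the "glimpse". The function names and parameters below are my own illustration, not the code from the actual model; in the real network, the centre and variance would be predicted by the attention network rather than fixed.

```python
import numpy as np

def gaussian_filters(img_size, n_points, center, sigma):
    """Build an (n_points, img_size) bank of 1-D Gaussian filters.

    Row i is a Gaussian centred at an evenly spaced point inside the
    attention window; multiplying the image by it softly samples the
    image around that point.
    """
    # Evenly spaced sample centres inside a window of width ~4 sigma.
    offsets = np.linspace(-1.0, 1.0, n_points) * 2 * sigma
    mu = center + offsets                          # (n_points,)
    positions = np.arange(img_size)                # (img_size,)
    f = np.exp(-0.5 * ((positions[None, :] - mu[:, None]) / sigma) ** 2)
    return f / (f.sum(axis=1, keepdims=True) + 1e-8)  # normalise rows

def glimpse(image, center_xy, sigma, glimpse_size=28):
    """Extract a glimpse_size x glimpse_size soft crop from the image."""
    h, w = image.shape
    fy = gaussian_filters(h, glimpse_size, center_xy[1], sigma)
    fx = gaussian_filters(w, glimpse_size, center_xy[0], sigma)
    return fy @ image @ fx.T                       # (glimpse_size, glimpse_size)

# Toy example: a bright square standing in for a digit.
image = np.zeros((100, 100))
image[40:60, 30:50] = 1.0
g = glimpse(image, center_xy=(40.0, 50.0), sigma=5.0)
print(g.shape)  # (28, 28)
```

Regularizing sigma to be small is what keeps the implied bounding box tight: a large sigma would let the network "cheat" by attending to the whole image.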
The model was trained on a modified version of the ubiquitous MNIST dataset. For those of you who don't know it: MNIST was created by Yann LeCun, Corinna Cortes, and Christopher J.C. Burges and consists of 60,000 28×28 greyscale images of the digits 0–9 for training and 10,000 for testing. Each image comes with a class label according to the digit it depicts.
However, the dataset itself is not useful for our weakly supervised learning setting, and so we need to modify it. First, let's make the images much larger, say 100×100, and embed the 28×28 original digits at random positions. We will add Gaussian background noise of varying levels to increase the problem's complexity. Moreover, we add a number of randomly rotated, smaller versions of digits from non-target classes. Finally, we also add digit noise, meaning that we alter a percentage of the pixels of the original 28×28 digit. An example is given in the figure below: eight examples from the transformed MNIST dataset. Each example comes with an integer label for the most prominent digit; no positional information is given.
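The core of the transformation, embedding a digit at a random position on a larger noisy canvas, can be sketched as follows. This is an illustrative simplification under my own assumptions (the actual pipeline also adds rotated distractor digits and digit-pixel noise), and `embed_digit` is a hypothetical helper, not a function from the original notebooks:

```python
import numpy as np

rng = np.random.default_rng(0)

def embed_digit(digit, canvas_size=100, noise_std=0.1):
    """Place a 28x28 digit at a random position on a noisy canvas.

    Returns the canvas plus the ground-truth top-left corner, which is
    kept only for evaluation -- training sees the class label alone.
    """
    h, w = digit.shape
    canvas = rng.normal(0.0, noise_std, (canvas_size, canvas_size))
    y = rng.integers(0, canvas_size - h + 1)
    x = rng.integers(0, canvas_size - w + 1)
    canvas[y:y + h, x:x + w] += digit
    return np.clip(canvas, 0.0, 1.0), (x, y)

digit = rng.random((28, 28))          # stand-in for an MNIST digit
img, (x, y) = embed_digit(digit)
print(img.shape)  # (100, 100)
```

Discarding `(x, y)` at training time is exactly what makes the setting weakly supervised: the position exists in the generator but is never shown to the model.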
The model is trained on a single K40 GPU and stored for inference on our online server. The app contains the trained version of our model, and you can use it to build and classify images. There are some parameters for image generation to play around with, e.g. the digit's ground-truth label, its x and y positions, the number of background noise digits, and the digit size. Whenever you change a parameter, the app re-evaluates the input and shows the resulting prediction. The prediction consists of the predicted class label, a bounding box, the Gaussian attention map, and the attention glimpse.
If you have any questions or remarks, please leave a comment below. The source code of the inference-based Streamlit app and the Jupyter notebooks for training are available here.