Recently, I came across a blog post by Adam Kosiorek entitled "Attention in Neural Networks and How to Use It". This very readable post covers visual attention models and how he used them in one of his papers. I decided to give it a try myself, and here is the result: the simplest approach that still captures all the details needed to showcase this interesting field of research. As an example application, we will infer the position and size of digits in noisy images. The catch is that we will not use this rich information to train the algorithm. Instead, we will only provide class labels (i.e., which digit is somewhere in the image).
The problem and how to approach it
I found the following definition online: "Weakly supervised learning is a machine learning framework where the model is trained using examples that are only partially annotated or labeled."
Weakly supervised learning, in this context, refers to the problem of inferring not only the digit label but also the digit's position and size in the form of a bounding box. This is much more information than we used to train the model, which only ever saw the image and the corresponding class label.
The core idea is to use, besides the main neural network that classifies the input image, a (Gaussian) soft attention network that selects and weights those parts of the image that, we hope, contain the actual digit. Both networks are coupled so that the attention network can be trained on class labels alone. To make sure that the bounding box is tight, we regularize the variance parameters of the soft attention network to be small.
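To make the attention mechanism a bit more concrete, here is a minimal sketch in plain Python. It assumes a separable (axis-aligned) Gaussian described by a center (`mu_x`, `mu_y`) and standard deviations (`sigma_x`, `sigma_y`); in the actual model these four numbers would be predicted by the attention network, whereas here they are passed in by hand. The function names and the penalty weight of 0.01 are hypothetical and not taken from the post's notebooks. The glimpse is the element-wise product of image and attention map, and the loss adds a penalty on the squared standard deviations to the classification cross-entropy, which is one way to express "regularize the variance parameters to be small".

```python
import math


def gaussian_attention_map(height, width, mu_x, mu_y, sigma_x, sigma_y):
    """Separable Gaussian attention map with peak weight 1 at (mu_x, mu_y).

    Hypothetical parameterization: in the real model, mu and sigma would be
    predicted by the attention network from the input image.
    """
    # 1D Gaussian weights along each axis (unnormalized, so the peak is 1.0)
    gx = [math.exp(-0.5 * ((x - mu_x) / sigma_x) ** 2) for x in range(width)]
    gy = [math.exp(-0.5 * ((y - mu_y) / sigma_y) ** 2) for y in range(height)]
    # Outer product gives the 2D map: weight[y][x] = gy[y] * gx[x]
    return [[wy * wx for wx in gx] for wy in gy]


def glimpse(image, attention):
    """Soft glimpse: element-wise product of image and attention weights."""
    return [
        [pixel * weight for pixel, weight in zip(img_row, att_row)]
        for img_row, att_row in zip(image, attention)
    ]


def variance_penalty(sigma_x, sigma_y, weight=0.01):
    """Hypothetical regularizer: small sigmas mean a tight attention window."""
    return weight * (sigma_x ** 2 + sigma_y ** 2)


def total_loss(class_probs, label, sigma_x, sigma_y, weight=0.01):
    """Cross-entropy on the class label plus the variance penalty.

    Only the class label enters the loss; no position or size information.
    """
    return -math.log(class_probs[label]) + variance_penalty(sigma_x, sigma_y, weight)


# Usage: attend to a 100x100 image of ones around (x=40, y=60)
attention = gaussian_attention_map(100, 100, mu_x=40.0, mu_y=60.0, sigma_x=5.0, sigma_y=5.0)
patch = glimpse([[1.0] * 100 for _ in range(100)], attention)
print(patch[60][40])  # 1.0: the center pixel is kept unchanged
print(round(patch[60][50], 4))  # two sigmas away in x: exp(-2) ≈ 0.1353
```

Leaving the Gaussian unnormalized keeps the peak weight at 1, so the attended digit keeps its original intensity; a normalized Gaussian would also work, but it would rescale the glimpse whenever sigma changes.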
The model was trained on a modified version of the ubiquitous MNIST dataset. For those of you who don't know it: MNIST was created by Yann LeCun, Corinna Cortes and Christopher J.C. Burges and consists of 60,000 training and 10,000 test images of the handwritten digits 0 to 9, each a 28×28 grayscale image. Each image comes with a class label indicating the digit it depicts.
However, the dataset itself is not suited to our weakly supervised learning setting as-is, so we need to modify it. First, we make the images much larger, say 100×100, and embed the original 28×28 digits at random positions. We then add Gaussian background noise of varying levels to increase the difficulty of the problem. Moreover, we add a number of randomly rotated, smaller versions of digits from non-target classes. Finally, we also add digit noise, which means that we alter a percentage of the pixels of the original 28×28 digit. An example is given in the figure below.
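The data generation steps above can be sketched roughly as follows. This is a plain-Python illustration, not the code from the post's notebooks: the helper names, the default noise levels, the distractor scale factor of 2, and combining overlapping pixels with an element-wise maximum are all assumptions. To keep it short, distractors are rotated by multiples of 90 degrees and shrunk by pixel subsampling instead of using arbitrary-angle rotation and proper resizing. The returned position serves only as ground truth for evaluation; training would use just the class label.

```python
import random

CANVAS = 100  # enlarged image size, as in the post


def add_digit_noise(digit, fraction, rng):
    """Replace a fraction of the digit's pixels with random intensities in [0, 1)."""
    size = len(digit)
    noisy = [row[:] for row in digit]
    for idx in rng.sample(range(size * size), int(fraction * size * size)):
        noisy[idx // size][idx % size] = rng.random()
    return noisy


def embed_digit(digit, rng, canvas_size=CANVAS):
    """Place a square digit at a random position on an empty canvas.

    Returns the canvas and the top-left corner (x, y) of the digit. This ground
    truth is only for evaluation; it is never used for training.
    """
    size = len(digit)
    x = rng.randint(0, canvas_size - size)
    y = rng.randint(0, canvas_size - size)
    canvas = [[0.0] * canvas_size for _ in range(canvas_size)]
    for r in range(size):
        for c in range(size):
            canvas[y + r][x + c] = digit[r][c]
    return canvas, (x, y)


def rotate90(img, k):
    """Rotate a square image by k * 90 degrees counter-clockwise."""
    for _ in range(k % 4):
        img = [list(row) for row in zip(*img)][::-1]
    return img


def add_distractor(canvas, digit, rng, scale=2):
    """Paste a smaller, randomly rotated non-target digit at a random spot (in place).

    Simplification (hypothetical): shrink by pixel subsampling, rotate by a
    multiple of 90 degrees, and combine overlapping pixels with max().
    """
    small = rotate90([row[::scale] for row in digit[::scale]], rng.randint(0, 3))
    size = len(small)
    x = rng.randint(0, len(canvas[0]) - size)
    y = rng.randint(0, len(canvas) - size)
    for r in range(size):
        for c in range(size):
            canvas[y + r][x + c] = max(canvas[y + r][x + c], small[r][c])


def add_background_noise(canvas, sigma, rng):
    """Add zero-mean Gaussian noise and clip intensities to [0, 1]."""
    return [[min(1.0, max(0.0, p + rng.gauss(0.0, sigma))) for p in row] for row in canvas]


def make_example(digit, distractors, rng, noise_sigma=0.1, digit_noise=0.1):
    """Build one 100x100 image (noise levels are hypothetical defaults)."""
    canvas, position = embed_digit(add_digit_noise(digit, digit_noise, rng), rng)
    for other in distractors:
        add_distractor(canvas, other, rng)
    return add_background_noise(canvas, noise_sigma, rng), position


# Usage with a stand-in "digit" (a vertical bar) instead of a real MNIST image
rng = random.Random(0)
bar = [[1.0 if 10 <= c < 18 else 0.0 for c in range(28)] for _ in range(28)]
image, (x, y) = make_example(bar, distractors=[bar, bar], rng=rng)
print(len(image), len(image[0]), (x, y))
```

Passing an explicit `random.Random` instance keeps the generated images reproducible for a given seed, which is handy when comparing noise levels.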
The model was trained on a single K40 GPU and stored for inference on our online server. The app contains the trained model, and you can use it to generate an image and have it classified. Several image-generation parameters are available to play around with, e.g. the ground-truth digit label, the x and y position, the number of background noise digits, and the digit size. Whenever you change a parameter, the app re-evaluates the input and shows the resulting prediction. The prediction consists of the predicted class label, a bounding box, the Gaussian attention map, and the attention glimpse.
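The post does not spell out how the bounding box is derived from the Gaussian attention parameters. One plausible approach, shown here as a hypothetical sketch, is to take the box spanning mu ± k·sigma along each axis and clip it to the 100×100 image; the rule itself, the choice k = 2, and the function name `gaussian_to_bbox` are assumptions.

```python
def gaussian_to_bbox(mu_x, mu_y, sigma_x, sigma_y, width=100, height=100, k=2.0):
    """Turn Gaussian attention parameters into an axis-aligned bounding box.

    Hypothetical rule: the box spans mu +/- k * sigma along each axis, clipped
    to the image borders. Returns (x_min, y_min, x_max, y_max) in pixels.
    """
    x_min = max(0.0, mu_x - k * sigma_x)
    x_max = min(width - 1.0, mu_x + k * sigma_x)
    y_min = max(0.0, mu_y - k * sigma_y)
    y_max = min(height - 1.0, mu_y + k * sigma_y)
    return x_min, y_min, x_max, y_max


print(gaussian_to_bbox(40.0, 60.0, 5.0, 5.0))  # (30.0, 50.0, 50.0, 70.0)
print(gaussian_to_bbox(5.0, 95.0, 4.0, 4.0))   # (0.0, 87.0, 13.0, 99.0)
```

Because the variance penalty keeps sigma small, a box derived this way shrinks toward the digit; the multiplier k trades off between cutting off digit strokes and including surrounding background.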
Source code and more
If you have any questions or remarks, please leave a comment below. The source code of the inference-based Streamlit app and the Jupyter notebooks for training are available here. Please consider becoming a supporter if you enjoy our content. If you don't enjoy the content, please consider supporting us anyway.