CNN Explainer - Interpreting Convolutional Neural Networks (1/N)
In today’s article, we start a series that aims to demystify the results of Convolutional Neural Networks (CNNs). CNNs are very successful at solving many Computer Vision tasks, but they are still neural networks, so they tend to fall into the category of ‘black box’ systems that don’t explain their predictions out of the box.
However, in this project, we are going to explain their behavior by visualizing learned weights, activation maps, and occlusion tests. By the end of this series, you will be able to reason not only about what CNNs are predicting but also about why they are doing so.
Why Interpretability Matters
In the Machine Learning and Computer Vision communities, there is an urban legend that in the 80s, the US military wanted to use artificial neural networks to automatically detect camouflaged tanks.
To do so, they took a bunch of pictures of trees and bushes with tanks behind them and some pictures of the same trees and bushes, but without the tanks. They believed that with such an approach, the system would eventually be able to detect whether an image contained a hidden tank or not.
The system seemed to be working fine and the results were really impressive. However, they wanted to be sure that the system was able to generalize, so they took another set of pictures with and without tanks. This time the system performed badly and wasn’t able to detect whether images contained a hidden tank or not.
It turned out that all the pictures without tanks were taken on a cloudy day, while the ones with tanks, on a sunny day. As a result, the network learned to recognize weather instead of hidden tanks.
It’s unclear whether the above situation actually happened (for more info see this link), but it’s very easy to imagine that such situations may be possible, and easily overlooked in real-world scenarios.
Luckily, in this story the error was caught, but if the system had provided explanations for its predictions, the problem could have been detected and fixed at a very early stage of development.
Interpretability is helpful in many cases, but it’s especially important in life-critical medical applications, where there is a well-founded fear of blindly relying on neural network predictions alone.
Knowing why the network made a given prediction might be as important as the prediction itself: firstly, it could increase the confidence to follow through, and secondly, it could give the medical staff valuable insights that might be used elsewhere.
Interpretability clearly comes with multiple benefits, and in this series, we are going to conduct a number of experiments that should bring us closer to understanding CNNs.
Area Importance Heatmaps with Occlusions
There is a variety of approaches to make CNNs understandable, and the vast majority of them consist of evaluating their internals like weights and activations. This is a straightforward and promising idea, and we are going to try it in the next article of the series, but today we are going to try something simpler - occlusion tests.
The idea behind occlusion tests is very simple, yet surprisingly powerful, and we are going to go through the whole process step-by-step. Feel free to check the full codebase to follow along:
CNN Explainer is PyTorch based project that aims to make CNN's…
We are going to start with 4 images from different domains that represent classes found in the ImageNet dataset: a basketball, a pineapple, a submarine, and a zebra.
The second step is to run a CNN model pretrained on ImageNet to verify that it correctly recognizes the objects. We can easily do it with the following code snippet:
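import torch
from torchvision import models
model = models.resnet18(pretrained=True)  # newer torchvision versions use the weights= argument instead
model.eval()  # switch to inference mode
output = model(preprocessed_image)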
The output variable holds a vector of 1000 values.
Why 1000, you may ask.
That’s because there are 1000 classes in the ImageNet dataset, so our model outputs one score per class. Strictly speaking, these are raw scores (logits) rather than probabilities, but the ordering is the same: the higher the value, the more strongly our CNN model believes that the input image belongs to the given class.
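If actual probabilities are needed, the raw scores can be passed through a softmax, which rescales them so that they are positive and sum to 1. In PyTorch this is typically done with torch.softmax along the class dimension; the snippet below is only a dependency-free sketch of the same arithmetic, and the example scores in it are made up for illustration.

```python
import math

def softmax(scores):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    # Subtracting the maximum keeps exp() numerically stable
    # and does not change the resulting probabilities.
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw scores for three classes (not real model output).
example_scores = [2.0, 1.0, 0.1]
probabilities = softmax(example_scores)
```

Softmax keeps the ranking of the classes intact, so the index of the highest raw score is also the index of the highest probability.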
To find the top prediction, we need to find the index of the maximum value of the output tensor.
top_prediction = int(torch.argmax(output))
predicted_class = classes[top_prediction]
Here, classes is the mapping between the 1000 indices and their corresponding class names. You can find the full list there.
Finally, we run the resnet18 model on all 4 images:
Our CNN model correctly identified images that we provided as inputs. In other words, the model’s #1 predictions were correct. However, it’s often useful to check the model’s #2 predictions as well.
top_2_predictions = torch.topk(output, 2)
We’ll see why that’s beneficial later on.
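torch.topk returns a pair of tensors: the k largest values and the indices where they occur. Those indices can be looked up in the class mapping just like the top prediction. The sketch below mirrors that behavior in plain Python so it can be run without PyTorch; the top_k helper, the labels, and the scores are all hypothetical and only serve as an illustration.

```python
def top_k(scores, labels, k):
    """Return the k highest-scoring (label, score) pairs, best first."""
    # Sort class indices by their score, highest score first.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(labels[i], scores[i]) for i in ranked[:k]]

# Hypothetical labels and scores, for illustration only (not real model output).
labels = ["basketball", "volleyball", "pineapple", "zebra"]
scores = [9.1, 6.4, 0.3, 1.2]
top_2 = top_k(scores, labels, 2)
```

The runner-up class is often the one the model confuses the input with, which is what makes it useful when reading the heatmaps.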
The idea behind the occlusion test is pretty simple - we occlude regions of an input image, one at a time, with a block of random noise. For example, if our input image is 224x224 and we choose a block size of 28, we end up with (224/28)*(224/28)=64 images, each with a unique occlusion position to check.
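The block count can be checked by enumerating the top-left corner of every block. This is a minimal sketch that assumes square images and non-overlapping blocks; the occlusion_positions helper is a hypothetical name, not taken from the project's codebase.

```python
def occlusion_positions(image_size, block_size):
    """Return the (top, left) corner of every non-overlapping occlusion block."""
    if image_size % block_size != 0:
        raise ValueError("block_size should divide image_size evenly")
    steps = range(0, image_size, block_size)
    return [(top, left) for top in steps for left in steps]

# 224x224 image with 28x28 blocks, as in the example above.
positions = occlusion_positions(224, 28)
```

Here positions holds 64 corners, from (0, 0) to (196, 196), one per occluded image to evaluate.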
For every image with a single occlusion block, we are going to run a CNN prediction and compare the results with the initial non-occluded output.
If the model’s prediction is very different (high loss) from the non-occluded output, it means that we occluded something important, which forced the model to make a mistake.
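The loop described above might look like the sketch below. It makes several assumptions that the article does not spell out: the image is a small grayscale 2D list rather than a 224x224 tensor, predict is any callable that returns class probabilities, and the loss is the negative log probability of the class predicted for the non-occluded image. The toy_predict function is a stand-in for a real CNN, used only to keep the example runnable without PyTorch.

```python
import math
import random

def occlusion_losses(image, predict, target_class, block_size, rng):
    """Occlude one block at a time with random noise and record the loss."""
    rows = len(image) // block_size
    cols = len(image[0]) // block_size
    losses = [[0.0] * cols for _ in range(rows)]
    for r in range(rows):
        for c in range(cols):
            occluded = [row[:] for row in image]  # copy, keep the original intact
            for y in range(r * block_size, (r + 1) * block_size):
                for x in range(c * block_size, (c + 1) * block_size):
                    occluded[y][x] = rng.random()  # random noise in [0, 1)
            prob = predict(occluded)[target_class]
            # Assumed loss: negative log probability of the original class.
            losses[r][c] = -math.log(max(prob, 1e-12))
    return losses

def toy_predict(image):
    """Hypothetical stand-in model: class 0 depends only on the top-left 2x2 block."""
    p = sum(image[y][x] for y in range(2) for x in range(2)) / 4.0
    return [p, 1.0 - p]

# 4x4 toy image whose top-left 2x2 block is bright, everything else dark.
toy_image = [[1.0 if y < 2 and x < 2 else 0.0 for x in range(4)] for y in range(4)]
losses = occlusion_losses(toy_image, toy_predict, 0, 2, random.Random(0))
```

In this toy setup only the top-left block produces a non-zero loss, because it is the only region the stand-in model looks at.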
Having such information for every occluded block, we can draw an importance heatmap (where red signifies high importance and blue signifies low importance).
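Before coloring, the per-block losses are usually rescaled to a common range so that the most important block maps to one end of the colormap and the least important block to the other. The sketch below shows a simple min-max rescaling; it is an assumption about how the values could be prepared, not a description of the project's actual code, and the losses in it are made up.

```python
def normalize_heatmap(losses):
    """Rescale a 2D list of losses to [0, 1] (0 = least important, 1 = most)."""
    flat = [value for row in losses for value in row]
    low, high = min(flat), max(flat)
    if high == low:
        # Every block is equally important, so there is nothing to highlight.
        return [[0.0 for _ in row] for row in losses]
    return [[(value - low) / (high - low) for value in row] for row in losses]

# Hypothetical per-block losses, for illustration only.
example_losses = [[1.0, 2.0], [1.0, 3.0]]
normalized = normalize_heatmap(example_losses)
```

The normalized grid can then be upscaled to the input size and drawn with a colormap that runs from blue (0) to red (1); matplotlib's jet colormap is one such option, though the article does not say which one was used.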
Let’s take a look at an importance heatmap for a basketball (input, heatmap, overlay).
The first thing that may come to mind after looking at the above heatmap is that the model recognized that a basketball has a round shape. That’s a correct observation, but it’s not the whole story.
The black and white ovals mark areas that turned out to be very important for the model’s decision that the input image is a basketball.
The black oval marks an area that follows the round shape but actually lies outside the basketball’s bounds. It’s the CNN model’s way of saying:
If I saw something in that area that is not a background, I would probably infer that the presented object is not perfectly round, thus it’s not a ball, so it cannot be a basketball.
The white oval, on the other hand, covers an area that lies directly on the basketball, and it’s not some random part of it. If you look closely, you’ll see that this area shows the seam configuration that is specific to basketballs. Without this area, the CNN model could have predicted that this object is a volleyball (which was the model’s second prediction).
To wrap this up, for our CNN model, a round shape (black oval) makes the input object a ball, and a specific alignment of seams (white oval) makes it a basketball.
This example is not as straightforward as the previous one, but if you look closely, you’ll notice something interesting.
The black oval marks the submarine’s bridge, and the white oval marks an area behind the submarine, but far above it.
The submarine’s bridge is pretty obvious and doesn’t require further explanation, but what about that mysterious area in the white oval?
Do you already know what it might be?
If not, let’s check the model’s second prediction for the input image - it’s a grey whale…
The area in the white oval is the CNN model’s way of saying:
If I saw something in that area that is not water, I would probably infer that the presented object is a grey whale, because grey whales usually have their tails there.
There we go with another not-so-obvious example.
The black oval shows the area that makes a pineapple a pineapple. At least from our CNN’s point of view.
Why not the pineapple’s texture, you may ask.
Pineapples indeed have a very specific texture relative to other objects and even to other fruits, but it’s not exactly unique.
If we look at the CNN’s second prediction, we’ll see a custard apple there. I wasn’t aware of such a fruit before this experiment, so I had to google it, and this is how it looks:
A custard apple’s texture is not exactly the same as a pineapple’s, but it’s very similar, especially the pattern of flakes.
However, the area marked with the black oval shows that the biggest discriminator between pineapples and custard apples is the crown of leaves, and more specifically the area that connects the crown with the fruit of the pineapple.
Our final example - a zebra.
The white oval shows the area that covers stripe patterns specific to zebras, which distinguish them from, for example, horses.
The black oval, on the other hand, is more mysterious, but it makes immediate sense if we check the model’s second prediction, which is … an elephant.
It’s the CNN’s way of saying:
If I saw something in that area attached to the four-legged creature, I would probably infer that the presented object is an elephant, because elephants usually have their trunks there.
In this project, we showed that it’s possible to reason about CNNs and interpret their results without diving into their internals. However, such an approach is quite limited, so in the next part of the series, we are going to investigate CNN weights and activation maps to make CNN predictions even more interpretable.
Part 2: Visualizing Gradient Weighted Class Activations with GradCAM
Don’t forget to check the project’s GitHub page.
Questions? Comments? Feel free to leave your feedback in the comments section or contact me directly at https://gsurma.github.io.
And don’t forget to 👏 if you enjoyed this article 🙂.