10  Interpreting what convnets learn

This chapter covers

  • Visualizing the intermediate outputs (activations) of a convnet
  • Visualizing convnet filters with gradient ascent in input space
  • Visualizing heatmaps of class activation over an image with Grad-CAM

A fundamental problem when building a computer vision application is that of interpretability: why did our classifier think a particular image contained a refrigerator, when all we can see is a truck? This is especially relevant to use cases where deep learning is used to complement human expertise, such as medical imaging use cases. This chapter will familiarize you with a range of different techniques for visualizing what convnets learn and understanding the decisions they make.

It’s often said that deep learning models are “black boxes”: they learn representations that are difficult to extract and present in a human-readable form. Although this is partially true for certain types of deep learning models, it’s definitely not true for convnets. The representations learned by convnets are highly amenable to visualization, in large part because they’re representations of visual concepts. Since 2013, a wide array of techniques has been developed for visualizing and interpreting these representations. We won’t survey all of them, but we’ll cover three of the most accessible and useful ones:

  • Visualizing intermediate convnet outputs (intermediate activations): useful for understanding how successive convnet layers transform their input and for getting a first idea of the meaning of individual filters
  • Visualizing convnet filters: useful for understanding precisely what visual pattern or concept each filter is receptive to
  • Visualizing heatmaps of class activation in an image: useful for understanding which parts of an image led to a given classification decision and for locating objects in an image

For the first method—activation visualization—we’ll use the small convnet that we trained from scratch on the dogs vs. cats classification problem in chapter 8. For the next two methods, we’ll use a pretrained Xception model.

10.1 Visualizing intermediate activations

Visualizing intermediate activations consists of displaying the values returned by various convolution and pooling layers in a model, given a certain input (the output of a layer is often called its activation, the output of the activation function). This gives a view into how an input is decomposed into the different filters learned by the network. We want to visualize feature maps with three dimensions: width, height, and depth (channels). Each channel encodes relatively independent features, so the proper way to visualize these feature maps is by independently plotting the contents of every channel as a 2D image. Let’s start by loading the model that we saved in section 8.2:

library(keras3)

model <- load_model("convnet_from_scratch_with_augmentation.keras")
model
Model: "functional_3"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer_3 (InputLayer)      │ (None, 180, 180, 3)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ rescaling_1 (Rescaling)         │ (None, 180, 180, 3)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_11 (Conv2D)              │ (None, 178, 178, 32)   │           896 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_6 (MaxPooling2D)  │ (None, 89, 89, 32)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_12 (Conv2D)              │ (None, 87, 87, 64)     │        18,496 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_7 (MaxPooling2D)  │ (None, 43, 43, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_13 (Conv2D)              │ (None, 41, 41, 128)    │        73,856 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_8 (MaxPooling2D)  │ (None, 20, 20, 128)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_14 (Conv2D)              │ (None, 18, 18, 256)    │       295,168 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_9 (MaxPooling2D)  │ (None, 9, 9, 256)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_15 (Conv2D)              │ (None, 7, 7, 512)      │     1,180,160 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_average_pooling2d_3      │ (None, 512)            │             0 │
│ (GlobalAveragePooling2D)        │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 512)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_3 (Dense)                 │ (None, 1)              │           513 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 4,707,269 (17.96 MB)
 Trainable params: 1,569,089 (5.99 MB)
 Non-trainable params: 0 (0.00 B)
 Optimizer params: 3,138,180 (11.97 MB)

Next, we’ll get an input image: a picture of a cat, not part of the images the network was trained on.

Listing 10.1: Preprocessing a single image
img_path <- get_file(                            # Downloads a test image
  fname = "cat.jpg",
  origin = "https://img-datasets.s3.amazonaws.com/cat.jpg"
)

get_img_array <- function(img_path, target_size) {
  image <- img_path |>
    image_load(target_size = target_size) |>     # Opens the image file and resizes it
    image_to_array()                             # Turns the image into an array of shape (180, 180, 3)
  # Adds a dimension to turn the array into a "batch" of a single sample.
  # Its shape is now (1, 180, 180, 3).
  dim(image) <- c(1, dim(image))
  image
}

img <- get_img_array(img_path, target_size = c(180, 180))
str(img)
 num [1, 1:180, 1:180, 1:3] 3 8 11 9 10 13 22 21 59 86 ...

Let’s display the picture (see figure 10.1).

Listing 10.2: Displaying the test picture
display_image <- function(x, ..., max = 255L, margin = 0) {
  par(mar = rep(margin, 4))

  x |> as.array() |> drop() |>
    as.raster(max = max) |>
    plot(..., interpolate = FALSE)
}
display_image(img)

Figure 10.1: The test cat picture

To extract the feature maps we want to look at, we’ll create a Keras model that takes batches of images as input and outputs the activations of all convolution and pooling layers.

Listing 10.3: Instantiating a model that returns layer activations
is_conv_layer <- \(x) inherits(x, keras$layers$Conv2D)
is_pooling_layer <- \(x) inherits(x, keras$layers$MaxPooling2D)

layer_outputs <- list()
# Extracts the outputs of all Conv2D and MaxPooling2D layers and puts them in a named list
for (layer in model$layers)
  if (is_conv_layer(layer) || is_pooling_layer(layer))
    layer_outputs[[layer$name]] <- layer$output

# Creates a model that will return these outputs, given the model input
activation_model <- keras_model(
  inputs = model$input,
  outputs = layer_outputs
)

When fed an image input, this model returns the values of the layer activations in the original model as a named list. This is the first multi-output model you’ve encountered in practice since you learned about them in chapter 7: until now, the models you’ve seen have had exactly one input and one output. This model has one input and nine outputs, one per layer activation.

Listing 10.4: Using the model to compute layer activations
# Returns a named list of nine arrays: one array per layer activation
activations <- predict(activation_model, img)
str(activations)
List of 9
 $ conv2d_11      : num [1, 1:178, 1:178, 1:32] 0 0 0 0 0 0 0 0 0 0 ...
 $ max_pooling2d_6: num [1, 1:89, 1:89, 1:32] 0 0 0 0 0 0 0 0 0 0 ...
 $ conv2d_12      : num [1, 1:87, 1:87, 1:64] 0 0 0 0 0 0 0 0 0 0 ...
 $ max_pooling2d_7: num [1, 1:43, 1:43, 1:64] 0 0 0 0 0 0 0 0 0 0 ...
 $ conv2d_13      : num [1, 1:41, 1:41, 1:128] 0 0 0 0 0 0 0 0 0 0 ...
  [list output truncated]

For instance, this is the activation of the first convolution layer for the cat image input:

first_layer_activation <- activations[[ names(layer_outputs)[1] ]]
dim(first_layer_activation)
[1]   1 178 178  32

It’s a 178 × 178 feature map with 32 channels. Let’s try plotting the sixth channel of the activation of the first layer of the original model (see figure 10.2).

Listing 10.5: Visualizing the sixth channel
plot_activations <- function(x, ...) {
  withr::local_par(list(mar = c(0,0,0,0)))

  x <- drop(as.array(x))
  if (sum(x) == 0)
    return(plot(as.raster("gray")))

  rotate <- function(x) t(apply(x, 2, rev))
  graphics::image(
    rotate(x), asp = 1, axes = FALSE, useRaster = TRUE,
    col = viridis::viridis(256), ...
  )
}
plot_activations(first_layer_activation[, , , 6])

Figure 10.2: Sixth channel of the activation of the first layer on the test cat picture

This channel appears to encode a diagonal edge detector, but note that your own channels may vary because the specific filters learned by convolution layers aren’t deterministic.

Now, let’s plot a complete visualization of all the activations in the network (see figure 10.3). We’ll extract and plot every channel in each of the layer activations, arranging the results in one big grid per layer, with channels side by side.

Listing 10.6: Visualizing every channel in every intermediate activation
# Iterates over the activations (and the names of the corresponding layers)
for (layer_name in names(activations)) {
  layer_activation <- activations[[layer_name]]

  # The layer activation has shape (1, size, size, n_features).
  .[.., n_features] <- dim(layer_activation)

  # Prepares a plotting grid layout for all channels using par()
  par(mfrow = n2mfrow(n_features, asp = 1.75),
      mar = rep(.1, 4), oma = c(0, 0, 1.5, 0))

  # Plots every channel (feature) as its own panel
  for (j in 1:n_features)
    plot_activations(layer_activation[, , , j])

  # Adds a title for the full grid of activations for the layer
  title(main = layer_name, outer = TRUE)
}

Figure 10.3: Every channel of every layer activation on the test cat picture

There are a few things to note here:

  • The first layer acts as a collection of various edge detectors. At that stage, the activations retain almost all of the information present in the initial picture.
  • As we go higher, the activations become increasingly abstract and less visually interpretable. They begin to encode higher-level concepts such as “cat ear” and “cat eye.” Higher representations carry increasingly less information about the visual contents of the image and increasingly more information related to the class of the image.
  • The sparsity of the activations increases with the depth of the layer: in the first layer, all filters are activated by the input image, but in the following layers, more and more filters are blank, meaning that the pattern a blank filter encodes isn’t found in the input image (the sketch after this list shows one way to quantify this).
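
As a quick check, here is a minimal sketch, assuming the activations list computed in listing 10.4 is still in memory, that reports the fraction of channels in each layer whose activation is identically zero:

blank_channel_fraction <- sapply(activations, function(a) {
  a <- drop(a)                                   # Drops the batch dimension: shape becomes (size, size, n_features)
  n_features <- dim(a)[3]
  is_blank <- sapply(seq_len(n_features),
                     function(j) sum(a[, , j]) == 0)
  mean(is_blank)                                 # Proportion of blank channels in this layer
})
blank_channel_fraction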

We have just observed an important universal characteristic of the representations learned by deep neural networks: the features extracted by a layer become increasingly abstract with the depth of the layer. The activations of higher layers carry less and less information about the specific input being seen and more and more information about the target (in this case, the class of the image: cat or dog). A deep neural network effectively acts as an information distillation pipeline, with raw data going in (in this case, RGB pictures) and being repeatedly transformed so that irrelevant information is filtered out (for example, the specific visual appearance of the image), and useful information is magnified and refined (for example, the class of the image).

This is analogous to the way humans and animals perceive the world: after observing a scene for a few seconds, a human can remember which abstract objects were present in it (bicycle, tree) but can’t remember the specific appearance of these objects. In fact, if you tried to draw a generic bicycle from memory, chances are you couldn’t get it remotely right, even though you’ve seen thousands of bicycles in your lifetime (see, for example, figure 10.4). Try it right now: this effect is absolutely real. Your brain has learned to completely abstract its visual input—to transform it into high-level visual concepts while filtering out irrelevant visual details—making it tremendously difficult to remember how things around you look.

Figure 10.4: Left: Attempts to draw a bicycle from memory. Right: What a schematic bicycle should look like.

10.2 Visualizing convnet filters

Another easy way to inspect the filters learned by convnets is to display the visual pattern that each filter is meant to respond to. This can be done with gradient ascent in input space, directly optimizing the values of the input image of a convnet so as to maximize the response of a specific filter, starting from a blank input image. The resulting input image will be one that the chosen filter is maximally responsive to.

Let’s try this with the filters of the Xception model. The process is simple: we’ll build a loss function that maximizes the value of a given filter in a given convolution layer, and then we’ll use gradient ascent to adjust the values of the input image so as to maximize this activation value. This will be your second example of a low-level gradient-based loop: the first one in chapter 2 performed gradient descent to minimize a loss; here we’ll use gradient ascent to maximize an activation. We will show it for TensorFlow, PyTorch, and JAX.

First, let’s instantiate the Xception model trained on the ImageNet dataset. We can once again use the KerasHub library, exactly as we did in chapter 8:

py_require("keras-hub")
keras_hub <- import("keras_hub")

# Instantiates the feature extractor network from pretrained weights
model <- keras_hub$models$Backbone$from_preset(
  "xception_41_imagenet"
)

# Loads the matching preprocessing layer to scale our input images
preprocessor <- keras_hub$layers$ImageConverter$from_preset(
  "xception_41_imagenet",
  image_size = shape(180, 180)
)

We’re interested in the convolutional layers of the model: the Conv2D and SeparableConv2D layers. We need to know their names so we can retrieve their outputs. Let’s print their names in order of depth.

Listing 10.7: Printing the names of Xception convolutional layers
unlist(lapply(model$layers, \(layer) {
  if (inherits(layer, keras$layers$Conv2D) ||
      inherits(layer, keras$layers$SeparableConv2D))
    layer$name
}))

You’ll notice that the SeparableConv2D layers here are all named something like block6_sepconv1, block7_sepconv2, etc.—Xception is structured into blocks, each containing several convolutional layers.

Now, let’s create a second model that returns the output of a specific layer: a “feature extractor” model. Because our model is a Functional API model, it is inspectable: we can query the output of one of its layers and reuse it in a new model. No need to copy all of the Xception code.

Listing 10.8: Feature extractor model returning a specific output
# We could use the name of any layer in the Xception convolutional base here.
layer_name <- "block3_sepconv1"

# The layer object we're interested in
layer <- get_layer(model, name = layer_name)

# Uses model$input and layer$output to create a model that, given an input image,
# returns the output of our target layer
feature_extractor <- keras_model(inputs = model$input,
                                 outputs = layer$output)

To use this model, we can simply call it on some input data, but we should be careful to apply our model-specific image preprocessing so that our images are scaled to the same range as the Xception pretraining data:

activation <- img |> preprocessor() |> feature_extractor()

Let’s use our feature extractor model to define a function that returns a scalar value quantifying how much a given input image “activates” a given filter in the layer. This is the loss function that we’ll maximize during the gradient ascent process:

# The loss function takes an image tensor and the index of the filter we consider (an integer).
compute_loss <- function(image, filter_index) {
  activation <- feature_extractor(image)
  # We avoid border artifacts by involving only non-border pixels in the loss:
  # we discard a two-pixel border on each side of the activation.
  filter_activation <- activation@r[, 3:-3, 3:-3, filter_index]
  # Returns the mean of the activation values for the filter
  op_mean(filter_activation)
}
Note: The difference between predict(model, x) and model(x)

In the previous chapter, we used predict(model, x) for feature extraction. Here, we’re using model(x). What gives?

Both y <- predict(model, x) and y <- model(x) (where x is an array of input data) mean “run the model on x and retrieve the output y.” But they aren’t exactly the same thing.

predict() loops over the data in batches (in fact, we can specify the batch size via predict(model, x, batch_size=64)) and extracts the array value of the outputs. It’s schematically equivalent to

predict <- function(model, x, batch_size = 32) {
  y <- list()
  for (x_batch in split_into_batches(x, batch_size)) {
    y_batch <- as.array(model(x_batch))
    y[[length(y)+1]] <- y_batch
  }
  unsplit_batches(y)
}

This means predict() calls can scale to very large arrays. Meanwhile, model(x) happens in memory and doesn’t scale. On the other hand, predict() is not differentiable: TensorFlow, PyTorch, and JAX cannot backpropagate through it.

Another practical difference is that predict() returns an R array, whereas model(x) returns a tensor object (for a multi-output model, you get a list of R arrays from predict() and a list of tensors from model(x)).

You should use model(x) when you need to retrieve the gradients of the model call, and you should use predict() if you just need the output value. In other words, always use predict() unless you’re in the middle of writing a low-level gradient loop (as we are now).
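
To make the return-type difference concrete, here is a quick check using the feature extractor and the cat image batch from earlier (a sketch; both calls compute the same numbers, up to floating-point noise):

x <- preprocessor(img)                    # Scales the cat image batch as Xception expects
y1 <- predict(feature_extractor, x)       # An R array
y2 <- as.array(feature_extractor(x))      # A tensor, converted to an R array here
all.equal(y1, y2, tolerance = 1e-5)       # Should be TRUE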

A non-obvious trick to help the gradient-ascent process go smoothly is to normalize the gradient tensor by dividing it by its L2 norm (the square root of the sum of the squares of the values in the tensor). This ensures that the magnitude of the updates done to the input image is always within the same range.
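
In code, the trick as just described boils down to something like the following sketch (the listings that follow use keras3’s built-in op_normalize() for this step):

l2_normalize <- function(g, epsilon = 1e-8) {
  # Divides the gradient tensor by its L2 norm; the small epsilon guards against division by zero
  g / (op_sqrt(op_sum(op_square(g))) + epsilon)
}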

Let’s set up the gradient ascent step function. Anything that involves gradients requires calling backend-level APIs, such as GradientTape() in TensorFlow, $backward() in PyTorch, and jax$grad() in JAX. Let’s line up all the code snippets for each of the three backends, starting with TensorFlow.

10.2.1 Gradient ascent in TensorFlow

For TensorFlow, we can open a GradientTape() scope and compute the loss inside of it to retrieve the gradients we need. We’ll wrap it with tf_function() to speed up computation.

Listing 10.9: The gradient ascent step function in TensorFlow
library(tensorflow, exclude = c("set_random_seed", "shape"))

gradient_ascent_step <- tf_function(\(image, filter_index, learning_rate) {
  with(tf$GradientTape() %as% tape, {
    # Explicitly watches the image tensor, because it isn't a TensorFlow Variable
    # (only Variables are automatically watched in a gradient tape)
    tape$watch(image)
    # Computes the loss scalar, indicating how much the current image activates the filter
    loss <- compute_loss(image, filter_index)
  })
  # Computes the gradients of the loss with respect to the image
  grads <- tape$gradient(loss, image)
  # Applies the "gradient normalization trick"
  grads <- op_normalize(grads)
  # Returns the image, moved a little in a direction that activates our target filter more strongly
  image + (learning_rate * grads)
})

10.2.2 Gradient ascent in PyTorch

In the case of PyTorch, we use loss$backward() and image$grad to obtain the gradients of the loss with respect to the input image, like this:

torch <- import("torch")
gradient_ascent_step <- function(image, filter_index, learning_rate) {
  # Creates a copy of "image" that we can get gradients for
  image <- image$clone()$detach()$requires_grad_(TRUE)
  loss <- compute_loss(image, filter_index)
  loss$backward()
  grads <- image$grad
  grads <- op_normalize(grads)
  image + (learning_rate * grads)
}

No need to reset the gradients because the image tensor is re-created at each iteration.

10.2.3 Gradient ascent in JAX

In the case of JAX, we use jax$grad() to obtain a function that returns the gradients of the loss with respect to the input image:

jax <- import("jax")

grad_fn <- jax$grad(compute_loss)

gradient_ascent_step <- jax$jit(\(image, filter_index, learning_rate) {
  grads <- grad_fn(image, filter_index)
  grads <- op_normalize(grads)
  image + (learning_rate * grads)
})

10.2.4 The filter visualization loop

Now we have all the pieces. Let’s put them together in an R function that takes a filter index as input and returns a tensor representing the pattern that maximizes the activation of the specified filter in our target layer.

Listing 10.10: Function to generate filter visualizations
img_height <- img_width <- 200

generate_filter_pattern <- function(filter_index) {
  iterations <- 30      # Number of gradient ascent steps to apply
  learning_rate <- 10   # Amplitude of a single step

  # Initializes an image tensor with random values (the Xception model expects
  # input values in the [0, 1] range, so here we pick a range centered on 0.5)
  image <- random_uniform(
    minval = 0.4, maxval = 0.6,
    shape = shape(1, img_width, img_height, 3)
  )

  # Repeatedly updates the values of the image tensor to maximize our loss function
  for (i in seq(iterations))
    image <- gradient_ascent_step(image, filter_index, learning_rate)

  image
}

The resulting image tensor is a floating-point array of shape (1, 200, 200, 3), with values that may not be integers within [0, 255]. Hence, we need to post-process this tensor to turn it into a displayable image. We do so with the following straightforward utility function.

Listing 10.11: Utility function to convert a tensor into a valid image
deprocess_image <- function(image, crop = TRUE) {
  # Drops the batch dimension
  image <- op_squeeze(image, axis = 1)
  # Normalizes image values to lie within the [0, 255] range
  image <- image - op_mean(image)
  image <- image / op_std(image)
  image <- (image * 64) + 128
  image <- op_clip(image, 0, 255)
  if (crop) {
    # Center-crops to avoid border artifacts
    image <- image@r[26:-26, 26:-26, ]
  }
  op_cast(image, "uint8")
}

Let’s try it (see figure 10.5):

generate_filter_pattern(filter_index = 3L) |>
  deprocess_image() |>
  display_image()

Figure 10.5: Pattern that the third channel in layer block3_sepconv1 responds to maximally

It seems that filter 3 in layer block3_sepconv1 is responsive to a horizontal-lines pattern that looks somewhat like water or fur.

Now the fun part: we can start visualizing every filter in the layer—and even every filter in every layer in the model (see figure 10.6).

Listing 10.12: Generating a grid of all filter response patterns in a layer
par(mfrow = c(8, 8))
for (i in seq_len(64)) {
  generate_filter_pattern(filter_index = i) |>
    deprocess_image() |>
    display_image(margin = .1)
}

Figure 10.6: Some filter patterns for layers block2_sepconv1, block4_sepconv1, and block8_sepconv1

These filter visualizations tell us a lot about how convnet layers see the world: each layer in a convnet learns a collection of filters such that their inputs can be expressed as a combination of the filters. This is similar to how the Fourier transform decomposes signals onto a bank of cosine functions. The filters in these convnet filter banks become increasingly complex and refined as we go higher in the model:

  • The filters from the first layers in the model encode simple directional edges and colors (or colored edges, in some cases).
  • The filters from layers a bit further up the stack, such as block4_sepconv1, encode simple textures made from combinations of edges and colors.
  • The filters in higher layers begin to resemble textures found in natural images: feathers, eyes, leaves, and so on.
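
To produce grids like those in figure 10.6 for other layers, we just point the feature extractor at a different block and rerun listing 10.12. A minimal sketch, assuming the helpers defined earlier in this section are still in memory (note that the compiled gradient_ascent_step(), and grad_fn on the JAX backend, traced the previous feature_extractor, so re-create them before rerunning the loop):

layer_name <- "block8_sepconv1"   # Or block2_sepconv1, block4_sepconv1, ...
layer <- get_layer(model, name = layer_name)
feature_extractor <- keras_model(inputs = model$input,
                                 outputs = layer$output)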

10.3 Visualizing heatmaps of class activation

Here’s one last visualization technique—one that is useful for understanding which parts of a given image led a convnet to its final classification decision. This is helpful for “debugging” the decision process of a convnet, particularly in the case of a classification mistake (a problem domain called model interpretability). It can also allow us to locate specific objects in an image.

This general category of techniques is called class activation map (CAM) visualization, and it consists of producing heatmaps of class activation over input images. A class activation heatmap is a 2D grid of scores associated with a specific output class, computed for every location in any input image, indicating how important each location is with respect to the class under consideration. For instance, given an image fed into a dogs vs. cats convnet, CAM visualization would allow us to generate a heatmap for the class “cat,” indicating how cat-like different parts of the image are, and also a heatmap for the class “dog,” indicating how dog-like parts of the image are. The specific implementation we’ll use is described in Ramprasaath R. Selvaraju et al. (“Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization” (2019), https://arxiv.org/abs/1610.02391).

Grad-CAM consists of taking the output feature map of a convolution layer, given an input image, and weighting every channel in that feature map by the gradient of the class with respect to the channel. Intuitively, one way to understand this trick is that we’re weighting a spatial map of “how intensely the input image activates different channels” by “how important each channel is with regard to the class,” resulting in a spatial map of “how intensely the input image activates the class.”
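
In the notation of the Grad-CAM paper, writing A^k for the kth channel of the convolutional feature map, y^c for the score of class c, and Z for the number of spatial positions, this amounts to the following (the code below skips the paper’s final ReLU and uses a channel-wise mean rather than a sum, which only rescales the heatmap):

\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A_{ij}^k},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\Big( \sum_k \alpha_k^c A^k \Big)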

Let’s demonstrate this technique using the pretrained Xception model. Consider the image of two African elephants shown in figure 10.7, possibly a mother and her calf, strolling on the savanna. We can start by downloading this image and converting it to an array.

Figure 10.7: Test picture of African elephants

# Downloads the image and stores it locally under the path img_path
img_path <- get_file(
  fname = "elephant.jpg",
  origin = "https://img-datasets.s3.amazonaws.com/elephant.jpg"
)
# Loads the image, converts it to an array, and adds a batch dimension
img <- img_path |> image_load() |> image_to_array() |> op_expand_dims(1)

So far, we have used KerasHub only to instantiate a pretrained feature extractor network using the backbone class. For Grad-CAM, we need the entire Xception model, including the classification head: recall that Xception was trained on the ImageNet dataset with ~1 million labeled images belonging to 1,000 different classes.

KerasHub provides a high-level task API for common end-to-end workflows like image classification, text classification, image generation, and so on. A task wraps preprocessing, a feature extraction network, and a task-specific head into a single class that is easy to use. Let’s try it out:

model <- keras_hub$models$ImageClassifier$from_preset(
  "xception_41_imagenet",
  # We can configure the final activation of the classifier. Here we use a
  # softmax activation, so our outputs are probabilities.
  activation = "softmax"
)
preds <- predict(model, img)
# ImageNet has 1,000 classes, so each prediction from our classifier has 1,000 entries.
str(preds)
 num [1, 1:1000] 0.00000535 0.00002684 0.0000167 0.00001146 0.00001136 ...
decode_imagenet_predictions <- function(preds) {
  decoded <- keras_hub$utils$decode_imagenet_predictions(preds)
  # Converts the nested list into a list of data frames for convenience
  lapply(decoded, \(d) {
    .[class_name, score] <- purrr::list_transpose(d)
    tibble::tibble(class_name, score)
  })
}

decode_imagenet_predictions(preds)
[[1]]
# A tibble: 5 × 2
  class_name          score
  <chr>               <dbl>
1 African_elephant 0.903   
2 tusker           0.0549  
3 Indian_elephant  0.0164  
4 triceratops      0.000297
5 Mexican_hairless 0.000189

The top five classes predicted for this image are as follows:

  • African elephant (with 90% probability)
  • Tusker (with 5% probability)
  • Indian elephant (with 2% probability)
  • Triceratops and Mexican hairless dog with less than 0.1% probability

The network has recognized the image as containing an undetermined quantity of African elephants. The entry in the prediction vector that was maximally activated is the one corresponding to the “African elephant” class, at index 387:

which.max(preds[1, ])
[1] 387

To visualize which parts of the image are the most African elephant–like, let’s set up the Grad-CAM process.

Note that we didn’t need to preprocess our image before calling the task model. That’s because the KerasHub ImageClassifier preprocesses inputs for us as part of predict(). Let’s preprocess the image ourselves so we can use the preprocessed inputs directly:

# KerasHub tasks like ImageClassifier expose their preprocessing pipeline as a preprocessor layer.
img <- model$preprocessor(img)

First, we create a model that maps the input image to the activations of the last convolutional layer.

Listing 10.13: Returning the last convolutional output
last_conv_layer_name <- "block14_sepconv2_act"
last_conv_layer <- model$backbone$get_layer(last_conv_layer_name)
last_conv_layer_model <- keras_model(model$inputs, last_conv_layer$output)

Second, we create a model that maps the activations of the last convolutional layer to the final class predictions.

Listing 10.14: Going from the last convolutional output to final predictions
classifier_input <- last_conv_layer$output
x <- classifier_input
for (layer_name in c("pooler", "predictions")) {
  layer <- model$get_layer(layer_name)
  x <- layer(x)
}
classifier_model <- keras_model(classifier_input, x)

Then we compute the gradient of the top predicted class for our input image with respect to the activations of the last convolution layer. Once again, needing to compute gradients means we have to use backend APIs.

10.3.1 Getting the gradient of the top class: TensorFlow version

Let’s start with the TensorFlow version, once again using GradientTape.

Listing 10.15: Computing top class gradients with TensorFlow
tf <- import("tensorflow")
get_top_class_gradients <- function(image_tensor) {
  # Computes the activations of the last conv layer and makes the tape watch them
  last_conv_layer_output <- last_conv_layer_model(image_tensor)
  with(tf$GradientTape() %as% tape, {
    tape$watch(last_conv_layer_output)
    preds <- classifier_model(last_conv_layer_output)
    top_pred_index <- op_argmax(preds@r[1])
    # Retrieves the activation channel corresponding to the top predicted class
    top_class_channel <- preds@r[, top_pred_index]
  })

  # Gets the gradient of the top predicted class with regard to the
  # output feature map of the last convolutional layer
  grads <- tape$gradient(top_class_channel, last_conv_layer_output)
  list(grads, last_conv_layer_output)
}

10.3.2 Getting the gradient of the top class: PyTorch version

Next, here’s the PyTorch version, using $backward() and $grad.

Listing 10.16: Computing the top class gradients with PyTorch
torch <- import("torch")
get_top_class_gradients <- function(image_tensor) {
  # Computes the activations of the last conv layer, then creates a copy of
  # last_conv_layer_output that we can get gradients for
  last_conv_layer_output <- last_conv_layer_model(image_tensor)$
    clone()$detach()$requires_grad_(TRUE)

  preds <- classifier_model(last_conv_layer_output)
  top_pred_index <- op_argmax(preds@r[1])
  # Retrieves the activation channel corresponding to the top predicted class
  top_class_channel <- preds@r[, top_pred_index]
  # Gets the gradient of the top predicted class with regard to the
  # output feature map of the last convolutional layer
  top_class_channel$backward()
  grads <- last_conv_layer_output$grad
  list(grads, last_conv_layer_output)
}

10.3.3 Getting the gradient of the top class: JAX version

Finally, let’s do JAX. We define a separate loss-computation function that takes the final layer’s output and returns the activation channel corresponding to the top predicted class. We use this activation value as our loss, allowing us to compute the gradient.

Listing 10.17: Computing the top class gradients with JAX
jax <- import("jax")

# Defines a separate loss function: it maps the last conv layer's output to the
# activation value of the top-class channel
loss_fn <- function(last_conv_layer_output) {
  preds <- classifier_model(last_conv_layer_output)
  top_pred_index <- op_argmax(preds@r[1])
  top_class_channel <- preds@r[, top_pred_index]
  # Returns the activation value of the top-class channel as a scalar
  top_class_channel@r[1]
}

# Creates a gradient function
grad_fn <- jax$grad(loss_fn)

get_top_class_gradients <- function(image_tensor) {
  last_conv_layer_output <- last_conv_layer_model(image_tensor)
  # Retrieving the gradient of the top-class channel is just a matter of
  # calling the gradient function!
  grads <- grad_fn(last_conv_layer_output)
  list(grads, last_conv_layer_output)
}

10.3.4 Displaying the class activation heatmap

Next, we apply pooling and importance weighting to the gradient tensor to obtain our heatmap of class activation.

Listing 10.18: Gradient pooling and channel importance weighting
img <- img_path |> image_load() |> image_to_array() |> op_expand_dims(1)
img <- model$preprocessor(img)
.[grads, last_conv_layer_output] <- get_top_class_gradients(img)

# A tensor in which each entry is the mean intensity of the gradient for a given channel,
# quantifying the importance of each channel with regard to the top predicted class.
# It has shape (1, 1, 1, 2048).
pooled_grads <- op_mean(grads, axis = c(1, 2, 3), keepdims = TRUE)
# Multiplies each channel in the output of the last convolutional layer by
# "how important this channel is". The result has shape (1, 10, 10, 2048).
output <- last_conv_layer_output * pooled_grads
# The channel-wise mean of the resulting feature map is our heatmap of
# class activation, with shape (10, 10).
heatmap <- op_mean(output@r[1], axis = -1)

The result is shown in figure 10.8.

Listing 10.19: Visualizing the heatmap
plot_activations(heatmap)

Figure 10.8: Standalone class activation heatmap

Finally, let’s generate an image that superimposes the original image on the heatmap we just obtained (see figure 10.9).

Listing 10.20: Superimposing the heatmap on the original picture
# Converts the heatmap into a raster object with 40% opacity
palette <- hcl.colors(256, palette = "Spectral", alpha = .4)
heatmap <- as.array(-heatmap)
heatmap[] <- palette[cut(heatmap, 256)]
heatmap <- as.raster(heatmap)

# Loads the original image, this time without resizing or preprocessing
img <- image_load(img_path) |> image_to_array()
display_image(img)

# Superimposes the heatmap over the original image. Passing ncol(img) and nrow(img)
# stretches the heatmap, which has far fewer pixels, to the size of the original image;
# interpolate = FALSE keeps the activation map's pixel boundaries visible.
rasterImage(
  heatmap,
  0, 0, ncol(img), nrow(img),
  interpolate = FALSE
)

Figure 10.9: African elephant class activation heatmap over the test picture

This visualization technique answers two important questions:

  • Why did the network think this image contained an African elephant?
  • Where is the African elephant located in the picture?

In particular, it’s interesting to note that the ears of the elephant calf are strongly activated: this is probably how the network can tell the difference between African and Indian elephants.

10.4 Summary

  • Convnets process images by applying a set of learned filters. Filters from earlier layers detect edges and basic textures, and filters from later layers detect increasingly abstract concepts.
  • You can visualize both the pattern that a filter detects and a filter’s response map across an image.
  • You can use the Grad-CAM technique to visualize what area(s) in an image were responsible for a classifier’s decision.
  • Together, these techniques make convnets highly interpretable.