7  A deep dive on Keras

This chapter covers

  • Building Keras models with the Sequential model, the Functional API, and Model subclassing
  • Using the built-in training and evaluation loops
  • Writing custom metrics and callbacks to customize training
  • Using TensorBoard to monitor training and evaluation
  • Writing training and evaluation loops from scratch

You’re starting to have some experience with Keras. You’re familiar with the Sequential model, Dense layers, and built-in APIs for training, evaluation, and inference: compile(), fit(), evaluate(), and predict(). You even learned in chapter 3 how to inherit from the Layer class to create custom layers, and how to use the gradient APIs in TensorFlow, JAX, and PyTorch to implement a step-by-step training loop.

In the coming chapters, we’ll dig into computer vision, timeseries forecasting, natural language processing, and generative deep learning. These complex applications will require much more than a Sequential architecture and the default fit() loop. So let’s first turn you into a Keras expert! In this chapter, you’ll get a complete overview of the key ways to work with Keras APIs: everything you’re going to need to handle the advanced deep learning use cases you’ll encounter next.

7.1 A spectrum of workflows

The design of the Keras API is guided by the principle of progressive disclosure of complexity: make it easy to get started, yet make it possible to handle high-complexity use cases while requiring only incremental learning at each step. Simple use cases should be easy and approachable, and arbitrarily advanced workflows should be possible: no matter how niche and complex the thing you want to do, there should be a clear path to it, one that builds on what you learned from simpler workflows. This means you can grow from beginner to expert and still use the same tools—only in different ways.

As such, there’s not a single “true” way to use Keras. Rather, Keras offers a spectrum of workflows, from the very simple to the very flexible. There are different ways to build Keras models and different ways to train them, answering different needs.

Each way of building a model and each way of training one represents a certain tradeoff between usability and flexibility. We can use Keras as we would use tidymodels—just calling fit() and letting the framework do its thing—or we can use it like base R—taking full control of every little detail.

Because all these workflows are based on shared APIs, such as Layer and Model, components from any workflow can be used in any other workflow: they can all talk to each other. This means everything you’re learning will still be relevant once you’ve become an expert. You can get started easily and then gradually dive into workflows where you’re writing more logic from scratch. You won’t have to switch to an entirely different framework as you go from student to researcher, or from data scientist to deep learning engineer.

This philosophy is not unlike that of R itself! Some languages offer only one way to write programs: for instance, object-oriented programming or functional programming. Meanwhile, R is a multi-paradigm language: it offers a range of possible usage patterns, which all work nicely together. This makes R suitable for a wide range of different use cases. Likewise, you can think of Keras as the R of deep learning: a user-friendly deep learning language that offers a variety of workflows for different user profiles.

7.2 Different ways to build Keras models

There are three APIs for building models in Keras, as shown in figure 7.1:

  • The Sequential model—The most approachable API; it’s basically a flat list. As such, it’s limited to simple stacks of layers.
  • The Functional API—Focuses on graph-like model architectures. It represents a nice mid-point between usability and flexibility, and as such, it’s the most commonly used model-building API.
  • Model subclassing—A low-level option where we write everything from scratch. This is ideal if you want full control over every little thing. However, you won’t get access to many built-in Keras features, and you will be more at risk of making mistakes.

Figure 7.1: Progressive disclosure of complexity for model building

7.2.1 The Sequential model

The simplest way to build a Keras model is using the Sequential model with keras_model_sequential(), which you already know about.

Listing 7.1: The Sequential class
library(keras3)

model <- keras_model_sequential() |>
  layer_dense(64, activation = "relu") |>
  layer_dense(10, activation = "softmax")

Note that it’s also possible to build the same model incrementally, by piping the model into layer functions. Each call appends a layer to the model’s list of layers.

Listing 7.2: Incrementally building a Sequential model
model <- keras_model_sequential()
model |> layer_dense(64, activation = "relu")
model |> layer_dense(10, activation = "softmax")

You saw in chapter 3 that layers only get built (which is to say, their weights are created) when they are called for the first time. That’s because the shape of the layers’ weights depends on the shape of their input: until the input shape is known, they can’t be created. As such, the previous Sequential model does not have any weights until we call it on some data or call its build() method with an input shape.

Listing 7.3: Model that isn’t built and has no weights
1model$weights
1
At this point, the model isn’t built yet.
list()
Listing 7.4: Calling a model for the first time to build it
1model$build(input_shape = shape(NA, 3))
2str(model$weights)
1
Builds the model. Now the model will expect samples of shape (3). The NA in the input shape signals that the batch size could be anything.
2
Now we can retrieve the model’s weights.
List of 4
 $ :<Variable path=sequential_1/dense_2/kernel, shape=(3, 64), dtype=float32, value=[…]>
 $ :<Variable path=sequential_1/dense_2/bias, shape=(64), dtype=float32, value=[…]>
 $ :<Variable path=sequential_1/dense_3/kernel, shape=(64, 10), dtype=float32, value=[…]>
 $ :<Variable path=sequential_1/dense_3/bias, shape=(10), dtype=float32, value=[…]>

After the model is built, we can display its contents via the print() method, which comes in handy for debugging.

Listing 7.5: Printing a model summary
model
Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense_2 (Dense)                 │ (None, 64)             │           256 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_3 (Dense)                 │ (None, 10)             │           650 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 906 (3.54 KB)
 Trainable params: 906 (3.54 KB)
 Non-trainable params: 0 (0.00 B)

As you can see, this model happens to be named “sequential_1”. We can give names to everything in Keras—every model and every layer.

Listing 7.6: Naming models and layers with the name argument
model <- keras_model_sequential(name = "my_example_model")
model |> layer_dense(64, activation = "relu", name = "my_first_layer")
model |> layer_dense(10, activation = "softmax", name = "my_last_layer")
model$build(shape(NA, 3))
model
Model: "my_example_model"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ my_first_layer (Dense)          │ (None, 64)             │           256 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ my_last_layer (Dense)           │ (None, 10)             │           650 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 906 (3.54 KB)
 Trainable params: 906 (3.54 KB)
 Non-trainable params: 0 (0.00 B)

When building a Sequential model incrementally, it’s useful to be able to print a summary of what the current model looks like after we add each layer. But we can’t print a summary until the model is built! There’s a way to have a Sequential model be built on the fly: declare the shape of the model’s inputs in advance. We can do this via the input_shape argument.

Listing 7.7: Specifying the input shape of a model in advance
1model <- keras_model_sequential(input_shape = c(3))
model |> layer_dense(64, activation = "relu")
1
Uses input_shape to declare the shape of the inputs. Note that the argument must be the shape of each sample, not the shape of one batch.

Now we can use print() to follow how the output shape of the model changes as we add more layers:

model
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense_4 (Dense)                 │ (None, 64)             │           256 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 256 (1.00 KB)
 Trainable params: 256 (1.00 KB)
 Non-trainable params: 0 (0.00 B)
model |> layer_dense(10, activation = "softmax")
model
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense_4 (Dense)                 │ (None, 64)             │           256 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_5 (Dense)                 │ (None, 10)             │           650 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 906 (3.54 KB)
 Trainable params: 906 (3.54 KB)
 Non-trainable params: 0 (0.00 B)

This is a pretty common debugging workflow when dealing with layers that transform their inputs in complex ways, such as the convolutional layers you’ll learn about in chapter 8.

7.2.2 The Functional API

The Sequential model is easy to use, but its applicability is extremely limited: it can only express models with a single input and a single output, applying one layer after the other in a sequential fashion. In practice, it’s common to encounter models with multiple inputs (say, an image and its metadata), multiple outputs (different things we want to predict about the data), or a nonlinear topology.

In such cases, we build our model using the Functional API. This is what most Keras models you’ll encounter in the wild use. It’s fun and powerful—it feels like playing with LEGO bricks.

7.2.2.1 A simple example

Let’s start with something simple: the two-layer stack we used in the previous section. Its Functional API version looks like the following listing.

Listing 7.8: A simple Functional model with two Dense layers
inputs <- keras_input(shape = c(3), name = "my_input")
features <- inputs |> layer_dense(64, activation = "relu")
outputs <- features |> layer_dense(10, activation = "softmax")
model <- keras_model(inputs = inputs, outputs = outputs,
                     name = "my_functional_model")

Let’s go over this step by step. We start by creating a keras_input() (note that you can also give names to these input objects, like everything else):

inputs <- keras_input(shape = c(3), name = "my_input")

This inputs object holds information about the shape and dtype of the data the model will process:

1op_shape(inputs)
1
The model will process batches where each sample has shape (3). The number of samples per batch is variable (indicated by the NA batch size).
shape(NA, 3)
1op_dtype(inputs)
1
These batches will have dtype float32.
[1] "float32"

We call such an object a symbolic tensor. It doesn’t contain any actual data, but it encodes the specifications of the actual tensors of data that the model will see when we use it. It stands for future tensors of data.

Next, we create a layer and call it on the input:

features <- inputs |> layer_dense(64, activation = "relu")

All Keras layers can be called on real tensors of data or on these symbolic tensors. In the latter case, they return a new symbolic tensor, with updated shape and dtype information:

op_shape(features)
shape(NA, 64)
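
For instance, here’s a quick sketch of the first case (the standalone Dense layer below is a throwaway instance, created purely for illustration): calling a layer on an actual tensor returns an actual tensor rather than a symbolic one.

dense <- layer_dense(units = 2)
dense(op_ones(c(1, 3)))
# Returns an eager tensor of shape (1, 2), not a symbolic KerasTensor.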

After obtaining the final outputs, we instantiate the model by specifying its inputs and outputs in the Model constructor:

outputs <- features |> layer_dense(10, activation = "softmax")
model <- keras_model(inputs = inputs, outputs = outputs,
                     name = "my_functional_model")

Here’s the summary of our model:

model
Model: "my_functional_model"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ my_input (InputLayer)           │ (None, 3)              │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_8 (Dense)                 │ (None, 64)             │           256 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_9 (Dense)                 │ (None, 10)             │           650 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 906 (3.54 KB)
 Trainable params: 906 (3.54 KB)
 Non-trainable params: 0 (0.00 B)

7.2.2.2 Multi-input, multi-output models

Unlike this toy model, most deep learning models don’t look like lists—they look like graphs. They may, for instance, have multiple inputs or multiple outputs. It’s for this kind of model that the Functional API really shines.

Let’s say we’re building a system to rank customer support tickets by priority and route them to the appropriate department. Our model has three inputs:

  • The title of the ticket (text input)
  • The text body of the ticket (text input)
  • Any tags added by the user (categorical input, assumed here to be multi-hot encoded)

We can encode the text inputs as arrays of ones and zeros of size vocabulary_size (see chapter 14 for detailed information about text encoding techniques).

The model also has two outputs:

  • The priority score of the ticket, a scalar between 0 and 1 (sigmoid output)
  • The department that should handle the ticket (a softmax over the set of departments)

We can build this model in a few lines with the Functional API. The Functional API is a simple, LEGO-like, yet very flexible way to define arbitrary graphs of layers like these.

Listing 7.9: A multi-input, multi-output Functional model
vocabulary_size <- 10000
num_tags <- 100
num_departments <- 4

1title <- keras_input(c(vocabulary_size), name = "title")
text_body <- keras_input(c(vocabulary_size), name = "text_body")
tags <- keras_input(c(num_tags), name = "tags")

features <-
2  layer_concatenate(c(title, text_body, tags)) |>
3  layer_dense(64, activation = "relu", name = "dense_features")

4priority <- features |>
  layer_dense(1, activation = "sigmoid", name = "priority")

department <- features |>
  layer_dense(num_departments, activation = "softmax", name = "department")

5model <- keras_model(
  inputs = c(title, text_body, tags),
  outputs = c(priority, department)
)
1
Defines model inputs
2
Combines input features into a single tensor, features, by concatenating them
3
Applies intermediate layer to recombine input features into richer representations
4
Defines model outputs
5
Creates the model by specifying its inputs and outputs

7.2.2.3 Training a multi-input, multi-output model

We can train our model in much the same way we would train a Sequential model: by calling fit() with lists of input and output data. These lists of data should respect the same order we passed to the Model constructor (keras_model()).

Listing 7.10: Training a model by providing lists of input and target arrays
num_samples <- 1280

random_uniform_array <- function(dim) {
  array(runif(prod(dim)), dim)
}

random_integer_array <- function(dim, minval = 0L, maxval = 1L) {
  array(sample(minval:maxval, prod(dim), replace = TRUE), dim)
}

1title_data     <- random_integer_array(c(num_samples, vocabulary_size))
text_body_data <- random_integer_array(c(num_samples, vocabulary_size))
tags_data      <- random_integer_array(c(num_samples, num_tags))

2priority_data <- random_uniform_array(c(num_samples, 1))
department_data <- random_integer_array(
  dim = c(num_samples, 1),
  maxval = num_departments - 1
)

model |> compile(
  optimizer = "adam",
  loss = c("mean_squared_error", "sparse_categorical_crossentropy"),
  metrics = list(c("mean_absolute_error"), c("accuracy"))
)

model |> fit(
  x = list(title_data, text_body_data, tags_data),
  y = list(priority_data, department_data),
  epochs = 1
)

model |> evaluate(
  x = list(title_data, text_body_data, tags_data),
  y = list(priority_data, department_data)
)

.[priority_preds, department_preds] <- model |> predict(
  list(title_data, text_body_data, tags_data)
)
1
Dummy input data
2
Dummy target data

If we don’t want to rely on input order (for instance, because we have many inputs or outputs), we can also use the names we gave the keras_input() objects and output layers, and pass data via named lists.

Listing 7.11: Training a model by providing named lists of input and target arrays
model |> compile(
  optimizer = "adam",
  loss = list(
    priority = "mean_squared_error",
    department = "sparse_categorical_crossentropy"
  ),
  metrics = list(
    priority = c("mean_absolute_error"),
    department = c("accuracy")
  )
)

model |> fit(
  x = list(title = title_data, text_body = text_body_data, tags = tags_data),
  y = list(priority = priority_data, department = department_data),
  epochs = 1
)

model |> evaluate(
  x = list(title = title_data, text_body = text_body_data, tags = tags_data),
  y = list(priority = priority_data, department = department_data)
)

.[priority_preds, department_preds] <- model |> predict(
  list(title = title_data, text_body = text_body_data, tags = tags_data)
)

7.2.2.4 The power of the Functional API: Access to layer connectivity

A Functional model is an explicit graph data structure. This makes it possible to inspect how layers are connected and reuse previous graph nodes (which are layer outputs) as part of new models. It also nicely fits the “mental model” that most researchers use when thinking about a deep neural network: a graph of layers.

This enables two important use cases: model visualization and feature extraction. Let’s take a look.

7.2.2.4.1 Plotting layer connectivity

Let’s visualize the connectivity of the model we just defined (the topology of the model). We can plot a Functional model as a graph with the plot() method for models, as shown in figure 7.2:

plot(model)

Figure 7.2: Plot generated by plot() on our ticket classifier model

We can add to this plot the input and output shapes of each layer in the model, as well as layer names (rather than just layer types), which can be helpful during debugging (figure 7.3):

plot(model, show_shapes = TRUE, show_layer_names = TRUE)

Figure 7.3: Model plot with shape information added

The None in the tensor shapes represents the batch size: this model allows batches of any size. Depending on the context, an unspecified axis size may print as None, NULL, or NA—all three mean the same thing.

7.2.2.4.2 Feature extraction with a Functional model

Access to layer connectivity also means we can inspect and reuse individual nodes (layer calls) in the graph. The model property model$layers provides the list of layers that make up the model, and for each layer, we can query layer$input and layer$output.

Listing 7.12: Retrieving inputs or outputs of a layer in a Functional model
model$layers |> str()
List of 7
 $ :<InputLayer name=title, built=True>
 $ :<InputLayer name=text_body, built=True>
 $ :<InputLayer name=tags, built=True>
 $ :<Concatenate name=concatenate, built=True>
 $ :<Dense name=dense_features, built=True>
  [list output truncated]
model$layers[[4]]$input |> str()
List of 3
 $ :<KerasTensor shape=(None, 10000), dtype=float32, sparse=False, ragged=False, name=title>
 $ :<KerasTensor shape=(None, 10000), dtype=float32, sparse=False, ragged=False, name=text_body>
 $ :<KerasTensor shape=(None, 100), dtype=float32, sparse=False, ragged=False, name=tags>
model$layers[[4]]$output |> str()
<KerasTensor shape=(None, 20100), dtype=float32, sparse=False, ragged=False, name=keras_tensor_14>

This enables us to do feature extraction: creating models that reuse intermediate features from another model.

Let’s say we want to add another output to the model we previously defined: we also want to predict how long a given problem ticket will take to resolve—a kind of difficulty rating. We can do this via a classification layer over three categories: “quick,” “medium,” and “difficult.” We don’t need to re-create and retrain a model from scratch! We can just start from the intermediate features of our previous model, because we have access to them.

Listing 7.13: Creating a new model by reusing intermediate layer outputs
1features <- model$layers[[5]]$output
difficulty <- features |>
  layer_dense(3, activation = "softmax", name = "difficulty")

new_model <- keras_model(
  inputs = list(title, text_body, tags),
  outputs = list(priority, department, difficulty)
)
1
layers[[5]] is our intermediate Dense layer

Let’s plot our new model:

plot(new_model, show_shapes = TRUE, show_layer_names = TRUE)

Figure 7.4: Updated ticket classifier model with added difficulty output

7.2.3 Subclassing the Model class

The last model-building pattern you should know about is the most advanced one: Model subclassing. You learned in chapter 3 how to subclass the Layer class to create custom layers. Subclassing Model is pretty similar:

  1. In the initialize method, define the layers the model will use.
  2. In the call method, define the forward pass of the model, reusing the layers previously created.
  3. Instantiate your subclass and call it on data to create its weights.

7.2.3.1 Rewriting our previous example as a subclassed model

Let’s take a look at a simple example: we will reimplement the customer support ticket management model using a Model subclass.

Listing 7.14: A simple subclassed model
CustomerTicketModel <- new_model_class(
  classname = "CustomerTicketModel",

  initialize = function(num_departments) {
1    super$initialize()
2    self$concat_layer <- layer_concatenate()
    self$mixing_layer <- layer_dense(, 64, activation = "relu")
    self$priority_scorer <- layer_dense(, 1, activation = "sigmoid")
    self$department_classifier <- layer_dense(, num_departments,
                                              activation = "softmax")
  },

3  call = function(inputs) {
    .[title = title, text_body = text_body, tags = tags] <- inputs

    features <- list(title, text_body, tags) |>
      self$concat_layer() |>
      self$mixing_layer()
    priority <- features |> self$priority_scorer()
    department <- features |> self$department_classifier()
    list(priority, department)
  }
)
1
Don’t forget to call the super constructor!
2
Defines sublayers in the constructor
3
Defines the forward pass in the call() method

Once we’ve defined the model, we can instantiate it. Note that it will only create its weights the first time we call it on some data—much like Layer subclasses:

model <- CustomerTicketModel(num_departments = 4)
.[priority, department] <- model(list(
  title = title_data,
  text_body = text_body_data,
  tags = tags_data
))

So far, everything looks very similar to Layer subclassing, a workflow you’ve already encountered in chapter 3. What, then, is the difference between a Layer subclass and a Model subclass? It’s simple: a layer is a building block we use to create models, and a model is the top-level object that we will actually train, export for inference, etc. In short, a Model has fit(), evaluate(), and predict() methods. Layers don’t. Other than that, the two classes are virtually identical (another difference is that we can save a model to a file on disk—which we will cover in a few sections).

We can compile and train a Model subclass just like a Sequential or Functional model:

model |> compile(
  optimizer = "adam",
1  loss = c("mean_squared_error", "sparse_categorical_crossentropy"),
  metrics = c("mean_absolute_error", "accuracy")
)
model |> fit(
2  x = list(title = title_data,
           text_body = text_body_data,
           tags = tags_data),
  y = list(priority_data, department_data),
  epochs = 1
)
model |> evaluate(
  x = list(title = title_data,
           text_body = text_body_data,
           tags = tags_data),
  y = list(priority_data, department_data)
)
.[priority_preds, department_preds] <- model |> predict(
  list(title = title_data,
       text_body = text_body_data,
       tags = tags_data)
)
1
The structure of what we pass as the loss and metrics must match exactly what is returned by call(): we returned a list of two elements, so loss and metrics should be lists of two elements.
2
The structure of the input data must match exactly what is expected by the call() method, and the structure of the target data must match exactly what is returned by the call() method. Here, the input data must be a named list with three entries (title, text_body, and tags), and the target data must be a list of two elements.

The Model subclassing workflow is the most flexible way to build a model: it enables us to build models that cannot be expressed as directed acyclic graphs of layers: imagine, for instance, a model where the call() method uses layers inside a for loop, or even calls them recursively. Anything is possible—you’re in charge.
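
To make this concrete, here is a minimal sketch of such a model (a hypothetical example, not one we will train): its call() method applies the same layer several times inside a for loop, something neither a Sequential nor a Functional model can express.

IterativeRefiner <- new_model_class(
  classname = "IterativeRefiner",

  initialize = function(num_steps = 3) {
    super$initialize()
    self$num_steps <- as.integer(num_steps)
    self$project <- layer_dense(, 64, activation = "relu")
    self$refine <- layer_dense(, 64, activation = "relu")
    self$readout <- layer_dense(, 1)
  },

  call = function(inputs) {
    x <- self$project(inputs)
    # Apply the same refinement layer num_steps times.
    for (i in seq_len(self$num_steps))
      x <- self$refine(x)
    self$readout(x)
  }
)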

7.2.3.2 Beware: What subclassed models don’t support

This freedom comes at a cost: with subclassed models, we are responsible for more of the model logic, which means the potential error surface is much larger. As a result, we will have more debugging work to do. We are developing a new stateful object, not just snapping together LEGO bricks.

Functional and subclassed models are also substantially different in nature: a Functional model is an explicit data structure—a graph of layers, which we can view, inspect, and modify. Meanwhile, a subclassed model is a piece of bytecode: a class with a call() method that contains raw code. This is the source of the subclassing workflow’s flexibility—we can just code up whatever functionality we’d like—but it introduces new limitations.

For instance, because the way layers are connected to each other is hidden inside the body of the call() method, we cannot access that information. Calling print() will not display layer connectivity, and we cannot plot the model topology via plot(model). Likewise, if we have a subclassed model, we cannot access the nodes of the graph of layers to do feature extraction—because there is no graph. Once the model is instantiated, its forward pass becomes a complete black box.

7.2.4 Mixing and matching different components

Crucially, choosing one of these patterns—the Sequential model, the Functional API, or Model subclassing—does not lock us out of the others. All models in the Keras API can smoothly interoperate with each other, whether they’re Sequential models, Functional models, or subclassed models written from scratch. They’re all part of the same spectrum of workflows. For instance, we can use a subclassed layer or model in a Functional model.

Listing 7.15: Creating a Functional model that includes a subclassed model
Classifier <- new_model_class(
  classname = "Classifier",

  initialize = function(num_classes = 2) {
    super$initialize()
    if (num_classes == 2) {
      num_units <- 1
      activation <- "sigmoid"
    } else {
      num_units <- num_classes
      activation <- "softmax"
    }
    self$dense <- layer_dense(, num_units, activation = activation)
  },

  call = function(inputs) {
    self$dense(inputs)
  }
)

classifier <- Classifier(num_classes = 10)

inputs <- keras_input(shape = c(3))
outputs <- inputs |>
  layer_dense(64, activation = "relu") |>
  classifier()
model <- keras_model(inputs = inputs, outputs = outputs)

Conversely, we can use a Functional model as part of a subclassed layer or model.

Listing 7.16: Creating a subclassed model that includes a Functional model
inputs <- keras_input(shape = c(64))
outputs <- inputs |> layer_dense(1, activation = "sigmoid")
binary_classifier <- keras_model(inputs = inputs, outputs = outputs)

MyModel <- new_model_class(
  classname = "MyModel",

  initialize = function(num_classes = 2) {
    super$initialize()
    self$dense <- layer_dense(units = 64, activation = "relu")
    self$classifier <- binary_classifier
  },

  call = function(inputs) {
    inputs |>
      self$dense() |>
      self$classifier()
  }
)

model <- MyModel()

7.2.5 Remember: Use the right tool for the job

You’ve learned about the spectrum of workflows for building Keras models, from the simplest workflow (the Sequential model) to the most advanced one (model subclassing). When should you use one over the other? Each one has its pros and cons—pick the one most suitable for the job at hand.

In general, the Functional API provides a pretty good tradeoff between ease of use and flexibility. It also gives you direct access to layer connectivity, which is very powerful for use cases such as model plotting or feature extraction. If you can use the Functional API—that is, if your model can be expressed as a directed acyclic graph of layers—we recommend using it over model subclassing.

Going forward, all examples in this book will use the Functional API, simply because all the models we will work with are expressible as graphs of layers. We will, however, make frequent use of subclassed layers. In general, using Functional models that include subclassed layers provides the best of both worlds: high development flexibility while retaining the advantages of the Functional API.
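
For example, here is a minimal sketch of that pattern, assuming the new_layer_class() constructor from chapter 3 (the MLPBlock layer and its arguments are made up for illustration): a subclassed layer snapped into an otherwise ordinary Functional model.

layer_mlp_block <- new_layer_class(
  classname = "MLPBlock",

  initialize = function(units, dropout_rate = 0.5, ...) {
    super$initialize(...)
    # Sublayers are defined in the constructor, as in any subclassed layer or model.
    self$dense <- layer_dense(, units, activation = "relu")
    self$dropout <- layer_dropout(, dropout_rate)
  },

  call = function(inputs) {
    inputs |>
      self$dense() |>
      self$dropout()
  }
)

block <- layer_mlp_block(units = 64)

inputs <- keras_input(shape = c(3))
outputs <- inputs |>
  block() |>
  layer_dense(10, activation = "softmax")
model <- keras_model(inputs, outputs)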

7.3 Using built-in training and evaluation loops

The principle of progressive disclosure of complexity—access to a spectrum of workflows that go from dead easy to arbitrarily flexible, one step at a time—also applies to model training. Keras provides a range of workflows for training models: training can be as simple as calling fit() on our data or as involved as writing a new training algorithm from scratch.

You are already familiar with the compile(), fit(), evaluate(), predict() workflow. As a reminder, it looks like the following listing.

Listing 7.17: Standard workflow: compile() / fit() / evaluate() / predict()
1get_mnist_model <- function() {
  inputs <- keras_input(shape = c(28 * 28))
  outputs <- inputs |>
    layer_dense(512, activation = "relu") |>
    layer_dropout(0.5) |>
    layer_dense(10, activation = "softmax")
  keras_model(inputs, outputs)
}

2.[.[images, labels], .[test_images, test_labels]] <- dataset_mnist()
images <- array_reshape(images, c(60000, 28 * 28)) / 255
test_images <- array_reshape(test_images, c(10000, 28 * 28)) / 255
train_images <- images[10001:60000, ]
val_images <- images[1:10000, ]
train_labels <- labels[10001:60000]
val_labels <- labels[1:10000]

3model <- get_mnist_model()
model |> compile(
  optimizer = "adam",
  loss = "sparse_categorical_crossentropy",
  metrics = "accuracy"
)
4model |> fit(
  train_images, train_labels,
  epochs = 3,
  validation_data = list(val_images, val_labels)
)
5test_metrics <- model |> evaluate(test_images, test_labels)
6predictions <- model |> predict(test_images)
1
Creates a model. (We factor this into a separate function so we can reuse it later.)
2
Loads our data, reserving some for validation
3
Compiles the model by specifying its optimizer, the loss function to minimize, and metrics to monitor
4
Uses fit() to train the model, optionally providing validation data to monitor performance on unseen data
5
Uses evaluate() to compute the loss and metrics on new data
6
Uses predict() to compute classification probabilities on new data

There are a couple of ways to customize this simple workflow:

  • By providing custom metrics
  • By passing callbacks to the fit() method to schedule actions to be taken at specific points during training

Let’s take a look at these.

7.3.1 Writing our own metrics

Metrics are key to measuring the performance of a model—in particular, to measuring the difference between its performance on the training data and its performance on the test data. Commonly used metrics for classification and regression are already part of the built-in family of metric functions; most of the time, those are what you will use. But if you’re doing anything out of the ordinary, you will need to be able to write your own metrics. It’s simple!

A Keras metric is a subclass of the Keras Metric class. Similarly to layers, a metric has an internal state stored in Keras variables. Unlike layers, these variables aren’t updated via backpropagation, so we have to write the state update logic—which happens in the update_state() method. For example, here’s a simple custom metric that measures the root mean squared error (RMSE) with sparse (integer) labels.

Listing 7.18: Implementing a custom metric by subclassing the Metric class
1metric_sparse_root_mean_squared_error <- new_metric_class(
  classname = "SparseRootMeanSquaredError",

2  initialize = function(name = "rmse", ...) {
    super$initialize(name = name, ...)
    self$sum_sq_error <- self$add_weight(
      name = "sum_sq_error", initializer = "zeros"
    )
    self$total_samples <- self$add_weight(
      name = "total_samples", initializer = "zeros"
    )
  },

3  update_state = function(y_true, y_pred, sample_weight = NULL) {
    .[num_samples, num_classes] <- op_shape(y_pred)
    y_true <- op_one_hot(
      y_true,
      zero_indexed = TRUE,
      num_classes = num_classes
    )
    sse <- op_sum(op_square(y_true - y_pred))
    self$sum_sq_error$assign_add(sse)
    self$total_samples$assign_add(num_samples)
  },
1
Subclasses the Metric class
2
Defines the state variables in the constructor. Here we track the running sum of squared errors and the number of samples.
3
Implements the state update logic in update_state() for integer labels with categorical predictions. The y_true argument is the targets (or labels) for one batch, and y_pred represents the corresponding predictions from the model. To match our MNIST model, we expect categorical predictions and integer labels. Ignore the sample_weight argument; we won’t use it here.

We use the result() method to return the current value of the metric:

  result = function() {
    op_sqrt(op_divide_no_nan(
      self$sum_sq_error,
      self$total_samples
    ))
  },

Meanwhile, we also need to expose a way to reset the metric state without having to re-instantiate it: this enables the same metric objects to be used across different epochs of training or across both training and evaluation. We do this in the reset_state() method:

  reset_state = function() {
    self$sum_sq_error$assign(0)
    self$total_samples$assign(0)
  }
)

Custom metrics can be used just like built-in ones. Let’s test-drive our metric:

model <- get_mnist_model()
model |> compile(
  optimizer = "adam",
  loss = "sparse_categorical_crossentropy",
  metrics = list("accuracy", metric_sparse_root_mean_squared_error())
)
model |> fit(
  train_images, train_labels,
  epochs = 3,
  validation_data = list(val_images, val_labels)
)
test_metrics <- model |> evaluate(test_images, test_labels)

We can now see the fit() progress bar display the RMSE of our model.

7.3.2 Using callbacks

Launching a training run on a large dataset for tens of epochs using fit() can be a bit like launching a paper airplane: past the initial impulse, we don’t have any control over its trajectory or its landing spot. If we want to avoid bad outcomes (and thus wasted paper airplanes), it’s smarter to use not a paper plane but a drone that can sense its environment, send data back to its operator, and automatically make steering decisions based on its current state. The Keras callbacks API will help you transform your call to fit() from a paper airplane into a smart, autonomous drone that can self-introspect and dynamically take action.

A callback is an object (a class instance implementing specific methods) that is passed to the model in the call to fit() and that is called by the model at various points during training. It has access to all the available data about the state of the model and its performance, and it can take action: interrupt training, save a model, load a different weight set, or otherwise alter the state of the model.

Here are some examples of ways we can use callbacks:

  • Model checkpointing—Saving the current state of the model at different points during training.
  • Early stopping—Interrupting training when the validation loss is no longer improving (and of course, saving the best model obtained during training).
  • Dynamically adjusting the value of certain parameters during training—Such as the learning rate of the optimizer (a sketch follows this list).
  • Logging training and validation metrics during training, or visualizing the representations learned by the model as they’re updated—The fit() progress bar you’re familiar with is in fact a callback!
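
For instance, here is a minimal sketch of the learning rate use case, using the built-in callback_learning_rate_scheduler(); the schedule function itself is a made-up example.

# The schedule receives the (zero-based) epoch index and the current
# learning rate, and returns the learning rate to use for that epoch.
halve_every_10_epochs <- function(epoch, lr) {
  if (epoch > 0 && epoch %% 10 == 0) lr / 2 else lr
}
lr_callback <- callback_learning_rate_scheduler(halve_every_10_epochs)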

The callbacks function family includes a number of built-in callbacks (this is not an exhaustive list):

callback_model_checkpoint()
callback_early_stopping()
callback_learning_rate_scheduler()
callback_reduce_lr_on_plateau()
callback_csv_logger()

Let’s review two of them to give you an idea of how to use them: callback_early_stopping and callback_model_checkpoint.

7.3.2.1 The EarlyStopping and ModelCheckpoint callbacks

When we’re training a model, there are many things we can’t predict from the start. In particular, we can’t tell how many epochs will be needed to get to an optimal validation loss. Our examples so far have adopted the strategy of training for enough epochs that we begin overfitting, using the first run to figure out the optimal number of epochs, and then finally launching a new training run from scratch using this optimal number. Of course, this approach is wasteful. A much better way to handle this is to stop training when we measure that the validation loss is no longer improving. This can be achieved using the EarlyStopping callback.

The EarlyStopping callback interrupts training once a target metric being monitored has stopped improving for a fixed number of epochs. For instance, this callback allows us to interrupt training as soon as we start overfitting, thus avoiding having to retrain our model for a smaller number of epochs. This callback is typically used in combination with ModelCheckpoint, which lets us continually save the model during training (and, optionally, save only the current best model so far: the version of the model that achieved the best performance at the end of an epoch).

Listing 7.19: Using the callbacks argument in the fit() method
1callbacks_list <- list(
2  callback_early_stopping(
3    monitor = "val_accuracy",
4    patience = 1
  ),
5  callback_model_checkpoint(
6    filepath = "checkpoint_path.keras",
7    monitor = "val_loss",
    save_best_only = TRUE
  )
)

model <- get_mnist_model()
model |> compile(
  optimizer = "adam",
  loss = "sparse_categorical_crossentropy",
8  metrics = "accuracy"
)
9model |> fit(
  train_images, train_labels,
  epochs = 10,
  callbacks = callbacks_list,
  validation_data = list(val_images, val_labels)
)
1
Callbacks are passed to the model via the callbacks argument in fit(), which takes a list of callbacks. We can pass any number of callbacks.
2
Interrupts training when improvement stops
3
Monitors the model’s validation accuracy
4
Interrupts training when accuracy has stopped improving for one epoch
5
Saves the current weights after every epoch
6
Path to the destination model file
7
These two arguments mean we won’t overwrite the model file unless val_loss has improved, which allows us to keep the best model seen during training.
8
We monitor accuracy, so it should be part of the model’s metrics.
9
Because the callback will monitor validation loss and validation accuracy, we need to pass validation_data to the call to fit().

Note that we can always save models manually after training as well—just call save_model(model, "checkpoint_path.keras"). To reload the model we’ve saved, we use

model <- load_model("checkpoint_path.keras")

7.3.3 Writing our own callbacks

If we need to take a specific action during training that isn’t covered by one of the built-in callbacks, we can write our own callback. Callbacks are implemented by subclassing the class Callback. We can then implement any number of the following transparently named methods, which are called at various points during training:

1on_epoch_begin(epoch, logs)
2on_epoch_end(epoch, logs)
3on_batch_begin(batch, logs)
4on_batch_end(batch, logs)
5on_train_begin(logs)
6on_train_end(logs)
1
Called at the start of every epoch
2
Called at the end of every epoch
3
Called right before processing each batch
4
Called right after processing each batch
5
Called at the start of training
6
Called at the end of training

These methods are all called with a logs argument, which is a named list containing information about the previous batch, epoch, or training run: training and validation metrics, and so on. The on_epoch_* and on_batch_* methods also take the epoch or batch index as the first argument (an integer).
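
Before the fuller example below, here is a tiny sketch (a hypothetical callback) that simply reports which entries logs contains at the end of each epoch:

callback_print_logs <- new_callback_class(
  classname = "PrintLogs",

  on_epoch_end = function(epoch, logs = NULL) {
    # logs arrives as a named list of the metrics computed so far.
    cat("Epoch", epoch, "ended; logs contain:",
        paste(names(logs), collapse = ", "), "\n")
  }
)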

Here’s a simple example callback that saves a list of per-batch loss values during training and plots these values at the end of each epoch.

Listing 7.20: Creating a custom callback by subclassing the Callback class
callback_plot_per_batch_loss_history <- new_callback_class(
  classname = "PlotPerBatchLossHistory",

  initialize = function(file = "training_loss.pdf") {
    private$outfile <- file
  },

  on_train_begin = function(logs = NULL) {
    private$plots_dir <- tempfile()
    dir.create(private$plots_dir)
    private$per_batch_losses <-
1      reticulate::py_eval("[]", convert = FALSE)
  },

  on_epoch_begin = function(epoch, logs = NULL) {
    private$per_batch_losses$clear()
  },

  on_batch_end = function(batch, logs = NULL) {
    private$per_batch_losses$append(logs$loss)
  },

  on_epoch_end = function(epoch, logs = NULL) {
    losses <- as.numeric(reticulate::py_to_r(private$per_batch_losses))

    filename <- sprintf("epoch_%04i.pdf", epoch)
    filepath <- file.path(private$plots_dir, filename)

    pdf(filepath, width = 7, height = 5)
    on.exit(dev.off())

    plot(losses, type = "o",
         ylim = c(0, max(losses)),
         panel.first = grid(),
         main = sprintf("Training Loss for Each Batch\n(Epoch %i)", epoch),
         xlab = "Batch", ylab = "Loss")
  },

  on_train_end = function(logs) {
    private$per_batch_losses <- NULL
    plots <- sort(list.files(private$plots_dir, full.names = TRUE))
    qpdf::pdf_combine(plots, private$outfile)
    unlink(private$plots_dir, recursive = TRUE)
  }
)
1
Uses a Python list as a stateful stack that we can efficiently modify in place

Let’s test-drive it:

model <- get_mnist_model()
model |> compile(
  optimizer = "adam",
  loss = "sparse_categorical_crossentropy",
  metrics = "accuracy"
)
model |> fit(
  train_images, train_labels,
  epochs = 10,
  callbacks = list(callback_plot_per_batch_loss_history()),
  validation_data = list(val_images, val_labels)
)

We get plots that look like figure 7.5.

Figure 7.5: The output of our custom history plotting callback

7.3.4 Monitoring and visualization with TensorBoard

To do good research and develop good models, we need rich, frequent feedback about what’s going on inside our models during our experiments. That’s the point of running experiments: to get information about how well a model performs—as much information as possible. Making progress is an iterative process: we start with an idea and express it as an experiment, attempting to validate or invalidate it. We run this experiment and process the information it generates, as shown in figure 7.6. This inspires our next idea. The more iterations of this loop we can run, the more refined and powerful our ideas become. Keras helps us go from idea to experiment in the least possible time, and fast GPUs can help us get from experiment to result as quickly as possible. But what about processing the experiment results? That’s where TensorBoard comes in.

Figure 7.6: The loop of progress

TensorBoard is a browser-based application that we can run locally. It’s the best way to monitor everything that goes on inside our model during training. With TensorBoard, we can

  • Visually monitor metrics during training
  • Visualize our model architecture
  • Visualize histograms of activations and gradients
  • Explore embeddings in 3D

If we’re monitoring more information than just the model’s final loss, we can develop a clearer vision of what the model does and doesn’t do, and we can make progress more quickly.

The easiest way to use TensorBoard with a Keras model and the fit() method is with callback_tensorboard(). In the simplest case, we just specify where we want the callback to write logs, and we’re good to go:

model <- get_mnist_model()
model |> compile(
  optimizer = "adam",
  loss = "sparse_categorical_crossentropy",
  metrics = "accuracy"
)

model |> fit(
  train_images, train_labels,
  epochs = 10,
  validation_data = list(val_images, val_labels),
  callbacks = c(
    callback_tensorboard(
      log_dir = "./full_path_to_your_log_dir"
    )
  )
)

Once the model starts running, it will write logs at the target location. If you are running your R script on a local machine, you can launch the local TensorBoard server using the following function:

# Load TensorBoard in R
tensorboard(log_dir = "./full_path_to_your_log_dir")

In the TensorBoard interface, you can monitor live graphs of your training and evaluation metrics, as shown in figure 7.7.

Figure 7.7: TensorBoard can be used for easy monitoring of training and evaluation metrics.

7.4 Writing training and evaluation loops

The fit() workflow strikes a nice balance between ease of use and flexibility. It’s what you will use most of the time. However, it isn’t meant to support everything a deep learning researcher may want to do—even with custom metrics, custom losses, and custom callbacks.

After all, the built-in fit() workflow is solely focused on supervised learning: a setup where there are known targets (also called labels or annotations) associated with input data and where we compute our loss as a function of these targets and the model’s predictions. However, not every form of machine learning falls into this category. There are other setups where no explicit targets are present, such as generative learning (which we will introduce in chapter 16), self-supervised learning (where targets are obtained from the inputs), and reinforcement learning (where learning is driven by occasional “rewards”—much like training a dog). And even if you’re doing regular supervised learning, as a researcher, you may want to add some novel bells and whistles that require low-level flexibility.

Whenever you find yourself in a situation where the built-in fit() is not enough, you will need to write your own custom training logic. You’ve already seen simple examples of low-level training loops in chapters 2 and 3. As a reminder, the contents of a typical training loop look like this:

  1. Run the “forward pass” (compute the model’s output) to obtain a loss value for the current batch of data.
  2. Retrieve the gradients of the loss with regard to the model’s weights.
  3. Update the model’s weights to lower the loss value on the current batch of data.

These steps are repeated for as many batches as necessary. This is essentially what fit() does under the hood. In this section, you will learn to reimplement fit() from scratch, which will give you all the knowledge you need to write any training algorithm you may come up with.

Let’s go over the details. Throughout the next few sections, you’ll work your way up to writing a fully featured custom training loop in TensorFlow, PyTorch, and JAX.

7.4.1 Training vs. inference

In the low-level training loop examples you’ve seen so far, step 1 (the forward pass) was done via predictions = model(inputs), and step 2 (retrieving the gradients of the loss with respect to the model’s weights) was done via a backend-specific API, such as

  • gradients <- tape$gradient(loss, model$weights) in TensorFlow
  • loss$backward() in PyTorch
  • jax$value_and_grad() in JAX

In the general case, we need to take two subtleties into account. Some Keras layers, such as the Dropout layer, have different behaviors during training and during inference (when we use them to generate predictions). Such layers expose a training boolean argument in their call() method. Calling dropout(inputs, training=TRUE) will drop some activation entries, and calling dropout(inputs, training=FALSE) does nothing. By extension, Functional models and Sequential models also expose this training argument in their call() methods. Remember to pass training=TRUE when you call a Keras model during the forward pass! Our forward pass thus becomes predictions = model(inputs, training=TRUE).
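
For instance, here is a quick sketch with a standalone dropout layer (created just for illustration):

dropout <- layer_dropout(rate = 0.5)
x <- op_ones(c(2, 4))
dropout(x, training = TRUE)   # Some entries are zeroed (and the rest rescaled)
dropout(x, training = FALSE)  # Returned unchanged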

In addition, note that when we retrieve the gradients of the weights of our model, we should not use model$weights, but rather model$trainable_weights. Indeed, layers and models own two kinds of weights:

  • Trainable weights, meant to be updated via backpropagation to minimize the loss of the model, such as the kernel and bias of a Dense layer
  • Nontrainable weights, which are meant to be updated during the forward pass by the layers that own them. For instance, if we wanted a custom layer to keep a counter of how many batches it has processed so far, that information would be stored in a nontrainable weight, and at each batch, our layer would increment the counter by one (a sketch of such a layer follows below).

Among Keras built-in layers, the only layer that features nontrainable weights is the BatchNormalization layer, which we will introduce in chapter 9. The BatchNormalization layer needs nontrainable weights to track information about the mean and standard deviation of the data that passes through it, to perform an online approximation of feature normalization (a concept you learned about in chapters 4 and 6).
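
Here is a minimal sketch of that batch-counter idea (a hypothetical layer, again assuming the new_layer_class() constructor from chapter 3): its only weight is nontrainable and gets updated during the forward pass rather than by the optimizer.

layer_batch_counter <- new_layer_class(
  classname = "BatchCounter",

  initialize = function(...) {
    super$initialize(...)
    self$n_batches <- self$add_weight(
      name = "n_batches", initializer = "zeros", trainable = FALSE
    )
  },

  call = function(inputs) {
    # Updated during the forward pass, not via backpropagation
    self$n_batches$assign_add(1)
    inputs
  }
)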

7.4.2 Writing custom training step functions

Taking into account these two details, a supervised learning training step ends up looking like this, in pseudo-code:

train_step <- function(inputs, targets) {
1  predictions <-  model(inputs, training = TRUE)
2  loss <- loss_fn(targets, predictions)
3  gradients <- get_gradients_of(loss, wrt = model$trainable_weights)
4  optimizer$apply(gradients, model$trainable_weights)
}
1
Runs the forward pass
2
Computes the loss for the current batch
3
Retrieves the gradients of the loss with regard to the model’s trainable weights. This function doesn’t actually exist!
4
Updates the model’s trainable weights based on the gradients

This snippet is pseudo-code rather than real code because it includes an imaginary function, get_gradients_of(). In reality, retrieving gradients is done in a way that is specific to our current backend: JAX, TensorFlow, or PyTorch.

Let’s use what you learned about each framework in chapter 3 to implement a real version of this train_step() function. We’ll start with TensorFlow and PyTorch because these two make the job relatively easy. We’ll end with JAX, which is quite a bit more complex.

7.4.2.1 A TensorFlow training step function

TensorFlow lets us write code that looks pretty much like our pseudo-code snippet. The only difference is that the forward pass should take place inside a GradientTape scope. We can then use the tape object to retrieve the gradients:

library(tensorflow, exclude = c("set_random_seed", "shape"))
library(keras3)

use_backend("tensorflow")
get_mnist_model <- function() {
  inputs <- keras_input(shape = c(28 * 28))
  outputs <- inputs |>
    layer_dense(512, activation = "relu") |>
    layer_dropout(0.5) |>
    layer_dense(10, activation = "softmax")
  keras_model(inputs, outputs)
}
model <- get_mnist_model()
loss_fn <- loss_sparse_categorical_crossentropy()
optimizer <- optimizer_adam()

train_step <- function(inputs, targets) {
1  with(tf$GradientTape() %as% tape, {
2    predictions <- model(inputs, training = TRUE)
    loss <- loss_fn(targets, predictions)
  })
3  gradients <- tape$gradient(loss, model$trainable_weights)
4  optimizer$apply(gradients, model$trainable_weights)
  loss
}
1
Opens a GradientTape
2
Runs the forward pass
3
Retrieves the gradients from the tape
4
Updates the model’s trainable weights based on the gradients

Let’s run it for a single step:

batch_size <- 32
inputs <- train_images[1:batch_size, ]
targets <- train_labels[1:batch_size]
loss <- train_step(inputs, targets)
loss
tf.Tensor(2.5361915, shape=(), dtype=float32)

Easy enough! Let’s do PyTorch next.

7.4.2.2 A PyTorch training step function

When we use the PyTorch backend, all of our Keras layers and models inherit from the PyTorch torch$nn$Module class and expose the native Module API. As a result, our model, its trainable weights, and our loss tensor are all aware of each other and interact via three mechanisms: the loss$backward() method, the weight$value$grad attribute, and the model$zero_grad() method.

As a reminder from chapter 3, the mental model to keep in mind is as follows:

  • With each forward pass, PyTorch builds up a one-time computation graph that keeps track of the computation that just happened.
  • Calling $backward() on any given scalar node of this graph (like our loss) will run the graph backward starting from that node, automatically populating a tensor$grad attribute on all tensors involved (if they satisfy requires_grad=TRUE) containing the gradient of the output node with respect to that tensor. In particular, it will populate the grad attribute of our trainable parameters.
  • To clear the contents of that tensor$grad attribute, we should call tensor$grad <- NULL on all our tensors. Because it would be a bit cumbersome to do this on all model variables individually, we can just do it at the model level via model$zero_grad()—the zero_grad() call will propagate to all variables tracked by the model. Clearing gradients is critical because calls to backward() are additive: if we don’t clear the gradients at each step, the gradient values would accumulate, and training would not proceed.

Let’s chain these steps:

library(reticulate)
library(keras3)

use_backend("torch")
torch <- import("torch")
loss_fn <- loss_sparse_categorical_crossentropy()
optimizer <- optimizer_adam()
model <- get_mnist_model()

train_step <- function(inputs, targets) {
1  predictions <- model(inputs, training = TRUE)
  loss <- loss_fn(targets, predictions)
2  loss$backward()
  gradients <- model$trainable_weights |>
3    lapply(\(weight) weight$value$grad)
  with(torch$no_grad(), {
    optimizer$apply(gradients, model$trainable_weights)
4  })
5  model$zero_grad()
  loss
}
1
Runs the forward pass
2
Runs the backward pass, populating gradient values
3
Recovers the gradient associated with each trainable variable. Note that weight$value is the PyTorch tensor that contains the variable’s value.
4
Updates the model’s trainable weights based on the gradients. This must be done in a no_grad() scope.
5
Don’t forget to clear the gradients!

Let’s run it for a single step:

batch_size <- 32
inputs <- train_images[1:batch_size, ]
targets <- train_labels[1:batch_size]
loss <- train_step(inputs, targets)
loss
tensor(2.3919, device='cuda:0', grad_fn=<WhereBackward0>)

That wasn’t too difficult! Now, let’s move on to JAX.

7.4.2.3 A JAX training step function

When it comes to low-level training code, JAX tends to be the most complex of the three backends, because of its fully stateless nature. Statelessness makes JAX highly performant and scalable by enabling compilation and automatic performance optimizations. However, writing stateless code requires us to jump through some hoops.

Because the gradient function is obtained via metaprogramming, we first need to define the function that returns our loss. Further, this function needs to be stateless, so it needs to take as arguments all the variables it’s going to be using, and it needs to return the value of any variable it has updated. Remember those nontrainable weights that can get modified during the forward pass? Those are the variables we need to return.

To make it easier to work with the stateless programming paradigm of JAX, Keras models have a stateless forward-pass method: the stateless_call() method. It behaves just like call(), except that

  • It takes as input the model’s trainable weights and nontrainable weights, in addition to the inputs and training arguments.
  • It returns the model’s updated nontrainable weights, in addition to the model’s outputs.

It works like this:

.[outputs, non_trainable_weights] <- model$stateless_call(
    trainable_weights, non_trainable_weights, inputs
)

We can use stateless_call() to implement our JAX loss function. Because the loss function also computes updates for all nontrainable variables, we name it compute_loss_and_updates():

library(reticulate)
library(keras3)

use_backend("jax")
jax <- import("jax")
model <- get_mnist_model()
model$build(shape(28 * 28))
loss_fn <- loss_sparse_categorical_crossentropy()

compute_loss_and_updates <- function(
1  trainable_variables,
  non_trainable_variables,
  inputs,
  targets
) {
  .[outputs, non_trainable_variables] <-
2    model$stateless_call(
      trainable_variables,
      non_trainable_variables,
      inputs,
      training = TRUE
    )
  loss <- loss_fn(targets, outputs)
3  list(loss, non_trainable_variables)
}
1
Gradients are computed for the entries in the first argument (trainable_variables here)
2
Calls stateless_call()
3
Returns the scalar loss value and the updated nontrainable weights

Once we have this compute_loss_and_updates() function, we can pass it to jax$value_and_grad() to obtain the gradient-computing function:

grad_fn <- jax$value_and_grad(fn)
.[loss, gradients] <- grad_fn(...)

Now, there’s just a small problem. Both jax$grad() and jax$value_and_grad() require fn to return a scalar value only. Our compute_loss_and_updates() function returns a scalar value as its first output, but it also returns the new value for the nontrainable weights. Remember what you learned in chapter 3? The solution is to pass a has_aux argument to grad() or value_and_grad(), like this:

grad_fn <- jax$value_and_grad(compute_loss_and_updates, has_aux = TRUE)

We use it like this:

.[.[loss, non_trainable_weights], gradients] <- grad_fn(
    trainable_variables, non_trainable_variables, inputs, targets
)
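
To see the shape of the return value in isolation, here’s a toy sketch with a hypothetical function f (unrelated to our model): value_and_grad() with has_aux = TRUE returns a nested structure, list(list(value, aux), gradients), where gradients are taken with respect to the first argument only.

jnp <- jax$numpy
f <- function(x) {
  value <- jnp$sum(jnp$square(x))   # the scalar that gets differentiated
  aux <- jnp$mean(x)                # auxiliary output, passed through untouched
  list(value, aux)
}
toy_grad_fn <- jax$value_and_grad(f, has_aux = TRUE)
.[.[value, aux], grads] <- toy_grad_fn(jnp$array(c(1, 2, 3)))
grads                               # the gradient of sum(x^2) is 2 * x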

OK, that was a lot of JAXiness. But now we’ve got almost everything we need to assemble our JAX training step. The last piece of the puzzle is a stateless counterpart to optimizer$apply().

When we wrote our first basic training step in TensorFlow at the beginning of chapter 2, we wrote an update step function that looked like this:

learning_rate <- 1e-3

update_weights <- function(gradients, weights) {
  Map(\(w, g) w$assign(w - g * learning_rate),
      weights, gradients)
}

This corresponds to what Keras’s optimizer_sgd() would do. However, every other optimizer in the Keras API is somewhat more complex than that and keeps track of auxiliary variables that help speed up training—in particular, most optimizers use some form of momentum, which you learned about in chapter 2. These extra variables are updated at each step of training, and in the JAX world, that means we need to get our hands on a stateless function that takes these variables as arguments and returns their new value.
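
These auxiliary variables are exposed as optimizer$variables. Here’s a quick sketch of how much extra state Adam carries, reusing the model we built above (the exact counts depend on the optimizer and the model; the comments show what we’d expect for our MNIST model):

opt <- optimizer_adam()
opt$build(model$trainable_variables)
length(model$trainable_variables)   # 4: two kernels and two biases
length(opt$variables)               # more: an iteration counter plus per-variable moment estimates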

To make this easy, Keras makes available the stateless_apply() method on all optimizers. It works like this:

.[trainable_variables, optimizer_variables] <- optimizer$stateless_apply(
    optimizer_variables, grads, trainable_variables
)

Now we have enough to assemble an end-to-end training step:

optimizer <- optimizer_adam()
optimizer$build(model$trainable_variables)

1 train_step <- function(state, inputs, targets) {
  .[trainable_variables, non_trainable_variables, optimizer_variables] <-
2    state

3  .[.[loss, non_trainable_variables], grads] <- grad_fn(
    trainable_variables, non_trainable_variables,
    inputs, targets
  )
4  .[trainable_variables, optimizer_variables] <- optimizer$stateless_apply(
    optimizer_variables, grads, trainable_variables
  )
5  new_state <- list(
    trainable_variables,
    non_trainable_variables,
    optimizer_variables
  )
  list(loss, new_state)
}
1
The state is part of the function arguments.
2
Unpacks the state
3
Computes gradients and updates to nontrainable variables
4
Updates trainable variables and optimizer variables
5
Returns the updated state alongside the loss

Let’s run it for a single step:

batch_size <- 32
inputs <- train_images[1:batch_size, ]
targets <- train_labels[1:batch_size]

trainable_variables <- model$trainable_variables |> lapply(\(w) w$value)
non_trainable_variables <-
  model$non_trainable_variables |> lapply(\(w) w$value)
optimizer_variables <- optimizer$variables |> lapply(\(w) w$value)

state <- list(trainable_variables, non_trainable_variables, optimizer_variables)
.[loss, state] <- train_step(state, inputs, targets)
loss
Array(2.393485, dtype=float32)
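
To train for real, we’d call train_step() repeatedly over successive batches, threading the state through each call, and, once training is done, write the final values back into the model’s stateful Keras variables with assign(). Here’s a minimal sketch, reusing the objects defined above:

n_steps <- 5
for (step in seq_len(n_steps)) {
  idx <- ((step - 1) * batch_size + 1):(step * batch_size)
  .[loss, state] <- train_step(state, train_images[idx, ], train_labels[idx])
}

.[trainable_variables, non_trainable_variables, optimizer_variables] <- state
model_vars <- c(model$trainable_variables, model$non_trainable_variables)
new_values <- c(trainable_variables, non_trainable_variables)
for (i in seq_along(model_vars))
  model_vars[[i]]$assign(new_values[[i]])

This is also, roughly, what fit() does under the hood with the JAX backend: it trains on a detached copy of the state and syncs it back to the model’s variables afterward.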

It’s definitely a bit more work than TensorFlow and PyTorch, but the speed and scalability benefits of JAX more than make up for it. Next, let’s take a look at another important element of a custom training loop: metrics.

7.4.3 Low-level usage of metrics

In a low-level training loop, you will probably want to use Keras metrics (custom ones or the built-in ones). You’ve already learned about the metrics API: simply call update_state(y_true, y_pred) for each batch of targets and predictions and then use result() to query the current metric value:

metric <- metric_sparse_categorical_accuracy()
targets <- op_array(c(0, 1, 2), dtype = "int32")
predictions <- op_array(rbind(c(1, 0, 0), c(0, 1, 0), c(0, 0, 1)))
metric$update_state(targets, predictions)
current_result <- metric$result()
cat(sprintf("result: %.2f\n", current_result))
result: 1.00

You may also need to track the average of a scalar value, such as the model’s loss. You can do this via metric_mean():

mean_tracker <- metric_mean()
for(value in 0:4) {
  value <- op_array(value)
  mean_tracker$update_state(value)
}
cat(sprintf("Mean of values: %.2f\n", mean_tracker$result()))
Mean of values: 2.00

Remember to use metric$reset_state() when you want to reset the current results (at the start of a training epoch or at the start of evaluation).
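
In a custom loop, the typical pattern looks like this (a quick sketch reusing the metric, targets, and predictions from above, with a single “batch” per epoch):

for (epoch in 1:2) {
  metric$reset_state()
  metric$update_state(targets, predictions)
  cat(sprintf("Epoch %d accuracy: %.2f\n", epoch, metric$result()))
}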

Now, if you’re using JAX, state-modifying methods like update_state() or reset_state() can’t be used inside a stateless function. Instead, you can use the stateless metrics API, which is similar to the model$stateless_call() and optimizer$stateless_apply() methods you’ve already learned about. Here’s how it works:

metric <- metric_sparse_categorical_accuracy()
targets <- op_array(c(0, 1, 2), dtype = "int32")
predictions <-  op_array(rbind(c(1, 0, 0), c(0, 1, 0), c(0, 0, 1)))

1 metric_variables <- metric$variables
2 metric_variables <- metric$stateless_update_state(
  metric_variables, targets, predictions
)
3 current_result <- metric$stateless_result(metric_variables)
cat(sprintf("result: %.2f\n", current_result))

4 metric_variables <- metric$stateless_reset_state()
1
Gets the metric’s state variables
2
Gets updated values for the metric’s state
3
Computes the metric value corresponding to the current state
4
Gets blank variable values for the metric
result: 1.00

7.4.4 Using fit() with a custom training loop

In the previous sections, we wrote our own training logic entirely from scratch. Doing so provides the most flexibility, but we end up writing a lot of code while simultaneously missing out on many convenient features of fit(), such as callbacks, performance optimizations, and built-in support for distributed training.

What if we need a custom training algorithm, but we still want to use the power of the built-in Keras training loop? There’s actually a middle ground between fit() and a training loop written from scratch: we can provide a custom training step function and let the framework do the rest.

We can do this by overriding the train_step() method of the Model class. This is the function that is called by fit() for every batch of data. We can then call fit() as usual—and it will run our learning algorithm under the hood.

Here’s how it works:

  1. Create a new class that subclasses Model.
  2. Override the train_step() method. Its contents are nearly identical to what we used in the previous section.
  3. Return a named list mapping metric names (including the loss) to their current value.

Note the following:

  • This pattern does not prevent you from building models with the Functional API. You can do this whether you’re building Sequential models, Functional API models, or subclassed models.
  • You don’t need to wrap your code in tf_function() or jax$jit() when you override train_step(); the framework does it for you.

7.4.4.1 Customizing fit() with TensorFlow

Let’s start by coding a custom TensorFlow train step:

loss_fn <- loss_sparse_categorical_crossentropy()
1 loss_tracker <- metric_mean(name = "loss")

CustomModel <- new_model_class(
  "CustomModel",

2  train_step = function(data) {
    .[inputs, targets] <- data
    with(tf$GradientTape() %as% tape, {
3      predictions <- self(inputs, training = TRUE)
      loss <- loss_fn(targets, predictions)
    })
    gradients <- tape$gradient(loss, self$trainable_weights)
    self$optimizer$apply(gradients, self$trainable_weights)

4    loss_tracker$update_state(loss)
5    list("loss" = loss_tracker$result())
  },

  metrics = active_property(\() {
6    list(loss_tracker)
  })
)
1
Metric object to track the average of per-batch losses during training and evaluation
2
Overrides the train_step() method
3
We use self(inputs, training = TRUE) instead of model(inputs, training = TRUE) because our model is the instance of the class we’re defining: inside a method, self refers to the model itself.
4
Updates the loss tracker metric that tracks the average of the loss
5
Returns the average loss so far by querying the loss tracker metric
6
Lists the loss tracker metric so the model can automatically call reset_state() on it at the start of each epoch and at the start of a call to evaluate()—so we don’t have to do it by hand. Any metric you want to reset across epochs should be listed here.

We can now instantiate our custom model, compile it (we only pass the optimizer, because the loss is already defined outside of the model), and train it using fit() as usual. Let’s put the model definition in its own reusable function:

get_custom_model <- function() {
  inputs <- keras_input(shape = (28 * 28))
  outputs <- inputs |>
    layer_dense(512, activation = "relu") |>
    layer_dropout(0.5) |>
    layer_dense(10, activation = "softmax")
  model <- CustomModel(inputs, outputs)
  model |> compile(optimizer = optimizer_adam())
  model
}

Let’s give it a whirl:

model <- get_custom_model()
model |> fit(train_images, train_labels, epochs=3)

7.4.4.2 Customizing fit() with PyTorch

Next, the PyTorch version:

loss_fn <- loss_sparse_categorical_crossentropy()
loss_tracker <- metric_mean(name = "loss")

CustomModel <- new_model_class(
  "CustomModel",

  train_step = function(data) {
    .[inputs, targets] <- data
    self$zero_grad()               # clears leftover gradients from the previous step
1    predictions <- self(inputs, training = TRUE)
    loss <- loss_fn(targets, predictions)

2    loss$backward()
    trainable_weights <- self$trainable_weights
    gradients <- trainable_weights |> lapply(\(v) v$value$grad)

    with(torch$no_grad(), {
3      self$optimizer$apply(gradients, trainable_weights)
    })

4    loss_tracker$update_state(loss)
5    list(loss = loss_tracker$result())
  },

  metrics = active_property(\() {
    list(loss_tracker)
  })
)
1
Runs the forward pass
2
Runs the backward pass and retrieves the resulting gradients
3
Updates weights
4
Updates loss tracker metric
5
Returns the average loss so far by querying the loss tracker metric

Let’s try it:

model <- get_custom_model()
model |> fit(train_images, train_labels, epochs=3)

7.4.4.3 Customizing fit() with JAX

Finally, let’s write the JAX version. First we need to define a compute_loss_and_updates() method, similar to the compute_loss_and_updates() function we used in our custom training step example:

loss_fn <- loss_sparse_categorical_crossentropy()

CustomModel <- new_model_class(
  "CustomModel",

  compute_loss_and_updates = function(
    trainable_variables,
    non_trainable_variables,
    inputs,
    targets,
    training = FALSE
  ) {
    .[predictions, non_trainable_variables] <- self$stateless_call(
      trainable_variables,
      non_trainable_variables,
      inputs,
      training = training
    )
    loss <- loss_fn(targets, predictions)
1    list(loss, non_trainable_variables)
  },
1
Returns both the loss and the updated nontrainable variables

Note that, unlike with the other two backends, we aren’t computing a moving average of the loss here: we just return the per-batch loss value, which is less useful. We do this to keep the example simple; managing metric state in JAX makes the code considerably more verbose (you’ll learn how to handle it in the next section):

1  train_step = function(state, data) {
    .[
      trainable_variables,
      non_trainable_variables,
      optimizer_variables,
      metrics_variables
    ] <- state
    .[inputs, targets] <- data

2    grad_fn <- jax$value_and_grad(
      self$compute_loss_and_updates,
      has_aux = TRUE
    )

3    .[.[loss, non_trainable_variables], grads] <- grad_fn(
      trainable_variables,
      non_trainable_variables,
      inputs,
      targets,
      training = TRUE
    )

4    .[trainable_variables, optimizer_variables] <-
      self$optimizer$stateless_apply(
        optimizer_variables,
        grads,
        trainable_variables
      )

5    logs <- list(loss = loss)
    new_state <- list(
      trainable_variables,
      non_trainable_variables,
      optimizer_variables,
      metrics_variables
    )
6    list(logs, new_state)
  }
)
1
Unpacks the state. metrics_variables are part of it, although we won’t use them here.
2
Gets the gradient function
3
Computes gradients and updates to nontrainable variables
4
Updates trainable variables and optimizer variables
5
We aren’t computing a moving average of the loss; instead, we return the per-batch value.
6
Returns metric logs and updated state variables

Let’s try it out:

model <- get_custom_model()
model |> fit(train_images, train_labels, epochs=3)

7.4.5 Handling metrics in a custom train_step()

Finally, what about the loss and metrics that we can pass to compile()? After we’ve called compile(), we get access to

  • self$compute_loss()—Combines the loss function we passed to compile() with regularization losses that may be added by certain layers.
  • self$metrics—The list of metrics we passed to compile(). Note that it also includes a metric that tracks the loss.

7.4.5.1 train_step() metrics handling with TensorFlow

Here’s what it looks like with TensorFlow:

CustomModel <- new_model_class(
  "CustomModel",
  train_step = function(data) {
    .[inputs, targets] <- data
    with(tf$GradientTape() %as% tape, {
      predictions <- self(inputs, training = TRUE)
1      loss <- self$compute_loss(y=targets, y_pred=predictions)
    })
    gradients <- tape$gradient(loss, self$trainable_weights)
    self$optimizer$apply(gradients, self$trainable_weights)

    logs <- list()
2    for (metric in self$metrics) {
      if (metric$name == "loss")
        metric$update_state(loss)
      else
        metric$update_state(targets, predictions)

      logs[[metric$name]] <- metric$result()
    }

3    logs
  }
)
1
Computes the loss via self$compute_loss
2
Updates the model’s metrics, including the one that tracks the loss
3
Returns a named list mapping metric names to their current value

Let’s try it:

get_custom_model <- function() {
  inputs <- keras_input(shape = (28 * 28))
  outputs <- inputs |>
    layer_dense(512, activation = "relu") |>
    layer_dropout(0.5) |>
    layer_dense(10, activation = "softmax")

  model <- CustomModel(inputs, outputs)
  model |> compile(
    optimizer = optimizer_adam(),
    loss = loss_sparse_categorical_crossentropy(),
    metrics = c(metric_sparse_categorical_accuracy())
  )
  model
}

model <- get_custom_model()
model |> fit(train_images, train_labels, epochs = 3)

7.4.5.2 train_step() metrics handling with PyTorch

And here’s what it looks like with PyTorch; relative to the previous PyTorch example, the changes are exactly the same as the ones we just made for TensorFlow:

CustomModel <- new_model_class(
  "CustomModel",

  train_step = function(data) {
    .[inputs, targets] <- data
    self$zero_grad()               # clears leftover gradients from the previous step
    predictions <- self(inputs, training = TRUE)
    loss <- self$compute_loss(y = targets, y_pred = predictions)

    loss$backward()
    trainable_weights <- self$trainable_weights
    gradients <- trainable_weights |> lapply(\(v) v$value$grad)

    with(torch$no_grad(), {
      self$optimizer$apply(gradients, trainable_weights)
    })

    logs <- list()
    for (metric in self$metrics) {
      if (metric$name == "loss")
        metric$update_state(loss)
      else
        metric$update_state(targets, predictions)

      logs[[metric$name]] <- metric$result()
    }

    logs
  }
)

Let’s see how it runs:

model <- get_custom_model()
model |> fit(train_images, train_labels, epochs = 3)

7.4.5.3 train_step() metrics handling with JAX

Finally, here’s what it looks like with JAX. To start with, we call self$compute_loss() inside our compute_loss_and_updates() method so that it uses the loss we passed to compile():

CustomModel <- new_model_class(
  "CustomModel",
  compute_loss_and_updates = function(
    trainable_variables,
    non_trainable_variables,
    inputs,
    targets,
    training = FALSE
  ) {
    .[predictions, non_trainable_variables] <-
      self$stateless_call(
        trainable_variables,
        non_trainable_variables,
        inputs,
        training = training
      )
    loss <- self$compute_loss(y = targets, y_pred = predictions)
    tuple(loss, tuple(predictions, non_trainable_variables))
  },

Next up: metric management. As usual, it’s a tad more complicated due to JAX’s statelessness requirements:

  train_step = function(state, data) {
    .[
      trainable_variables,
      non_trainable_variables,
      optimizer_variables,
      metrics_variables
1    ] <- state
    .[inputs, targets] <- data

    grad_fn <-
      jax$value_and_grad(self$compute_loss_and_updates, has_aux = TRUE)

    .[.[loss, .[predictions, non_trainable_variables]], grads] <-
      grad_fn(
        trainable_variables,
        non_trainable_variables,
        inputs,
        targets,
        training = TRUE
      )

    .[trainable_variables, optimizer_variables] <-
      self$optimizer$stateless_apply(
        optimizer_variables,
        grads,
        trainable_variables
      )

    logs <- new_metrics_vars <- list()
2    for (metric in self$metrics) {
3      this_metric_vars <- metrics_variables[seq(
        from = length(new_metrics_vars) + 1,
        along.with = metric$variables
      )]
4      this_metric_vars <- if (metric$name == "loss") {
        metric$stateless_update_state(this_metric_vars, loss)
      } else {
        metric$stateless_update_state(this_metric_vars, targets, predictions)
      }
      logs[[metric$name]] <- metric$stateless_result(this_metric_vars)
5      new_metrics_vars <- c(new_metrics_vars, this_metric_vars)
    }

    new_state <- list(
      trainable_variables,
      non_trainable_variables,
      optimizer_variables,
6      new_metrics_vars
    )
    list(logs, new_state)
  }
)
1
Metric variables are part of the state.
2
Iterates over metrics
3
Grabs the current metric’s variables
4
Updates the metric state
5
Appends updated metric variables
6
Returns new metric variables in the state

That was a lot of information, but by now you know enough to use Keras to do almost anything!

7.5 Summary

  • Keras offers a spectrum of different workflows, based on the principle of progressive disclosure of complexity. They all interoperate smoothly.
  • You can build models via the Sequential class, via the Functional API, or by subclassing the Model class. Most of the time, you’ll be using the Functional API.
  • The simplest way to train and evaluate a model is via the default fit() and evaluate() methods.
  • Keras callbacks provide a simple way to monitor models during your call to fit() and automatically take action based on the state of the model.
  • You can also fully control what fit() does by overriding the train_step() method, using APIs from your backend of choice—JAX, TensorFlow, or PyTorch.
  • Beyond fit(), you can also write your own training loops entirely from scratch, in a backend-native way. This is useful for researchers implementing brand-new training algorithms.