This chapter covers
Our previous examples have assumed that we already had a labeled dataset to start from and that we could immediately start training a model. In the real world, this is often not the case. You don’t start from a dataset; you start from a problem.
Imagine that you’re launching your own machine learning consulting shop. You incorporate, you put up a fancy website, you notify your network. The projects start rolling in:
It would be very convenient if you could import the correct dataset from keras3::dataset_mydataset() and start fitting some deep learning models. Unfortunately, in the real world, you’ll have to start from scratch.
In this chapter, you’ll learn about the universal step-by-step blueprint that you can use to approach and solve any machine learning problem, like those previously listed. This template will bring together and consolidate everything you’ve learned in chapters 4 and 5 and give you the wider context that should anchor what you will learn in the next chapters.
The universal workflow of machine learning is broadly structured in three parts:
Let’s dive in.
You can’t do good work without a deep understanding of the context of what you’re doing. Why is your customer trying to solve this particular problem? What value will they derive from the solution? How will your model be used? How will it fit into your customer’s business processes? What kind of data is available or could be collected? What kind of machine learning task can be mapped to the business problem?
Framing a machine learning problem usually involves many detailed discussions with stakeholders. Here are the questions that should be at the top of your mind.
What will your input data be? What are you trying to predict? You can only learn to predict something if you have available training data: for example, you can only learn to classify the sentiment of movie reviews if you have both movie reviews and sentiment annotations available. As such, data availability is usually the limiting factor at this stage. In many cases, you will have to resort to collecting and annotating new datasets yourself (which we cover in the next section).
What type of machine learning task are you facing? Is it binary classification? Multiclass classification? Scalar regression? Vector regression? Multiclass, multilabel classification? Image segmentation? Ranking? Something else, like clustering, generation, or reinforcement learning? In some cases, it may be that machine learning isn’t even the best way to make sense of your data, and you should use something else, such as plain old-school statistical analysis:
What do existing solutions look like? Perhaps your customer already has a hand-crafted algorithm that handles spam filtering or credit card fraud detection—with lots of nested if statements. Perhaps a human is currently in charge of manually handling the process: monitoring the conveyor belt at the cookie plant and manually removing the bad cookies, or crafting playlists of song recommendations to be sent out to users who liked a specific artist. You should be sure to understand what systems are already in place and how they work.
Are there particular constraints you will need to deal with? For example, you may find out that the app for which you’re building a spam detection system is strictly end-to-end encrypted, so the spam detection model will have to live on the end user’s phone and must be trained on an external dataset. Perhaps the cookie filtering model has such latency constraints that it will need to run on an embedded device at the factory rather than on a remote server. You should understand the full context in which your work will fit.
Once you’ve done your research, you should know what your inputs will be, what your targets will be, and what broad type of machine learning task the problem maps to. Be aware of the hypotheses you’re making at this stage:
Until you have a working model, these are merely hypotheses, waiting to be validated or invalidated. Not all problems can be solved with machine learning; just because you’ve assembled examples of inputs X and targets Y doesn’t mean X contains enough information to predict Y. For instance, if you’re trying to predict the movements of a stock on the stock market given its recent price history, you’re unlikely to succeed, because price history doesn’t contain much predictive information.
You may sometimes be offered ethically dubious projects, such as “building an AI that rates the trustworthiness of someone from a picture of their face.” First of all, the validity of the project is in doubt: it isn’t clear why trustworthiness would be reflected on someone’s face. Second, such a task opens the door to all kinds of ethical problems. Collecting a dataset for this task would amount to recording the biases and prejudices of the people who label the pictures. The models you would train on such data would be merely encoding these same biases into a black-box algorithm, which would give them a thin veneer of legitimacy. In a largely tech-illiterate society like ours, “The AI algorithm said this person cannot be trusted” strangely appears to carry more weight and objectivity than “John Smith said this person cannot be trusted”—despite the former being a learned approximation of the latter. Your model would be laundering and operationalizing at scale the worst aspects of human judgment, with negative effects on the lives of real people.
Technology is never neutral. If your work has any effect on the world, then this effect has a moral direction: technical choices are also ethical choices. Always be deliberate about the values you want your work to support.
Once you understand the nature of the task and you know what your inputs and targets are going to be, it’s time for data collection—the most arduous, time-consuming, and costly part of most machine learning projects:
You learned in chapter 5 that a model’s ability to generalize comes almost entirely from the properties of the data it is trained on: the number of data points you have, the reliability of your labels, and the quality of your features. A good dataset is an asset worthy of care and investment. If you get an extra 50 hours to spend on a project, chances are that the most effective way to allocate them is to collect more data, rather than search for incremental modeling improvements.
The point that data matters more than algorithms was most famously made in a 2009 paper by Google researchers titled “The Unreasonable Effectiveness of Data” (the title is a riff on Eugene Wigner’s well-known 1960 article “The Unreasonable Effectiveness of Mathematics in the Natural Sciences”). This was before deep learning was popular, but remarkably, the rise of deep learning has only made the importance of data greater.
If you’re doing supervised learning, then once you’ve collected inputs (such as images), you’re going to need annotations for them (such as tags for those images): the targets you will train your model to predict. Sometimes annotations can be retrieved automatically—for instance, in the case of the music recommendation task or the click-through rate prediction task. But often, you have to annotate your data by hand. This is a labor-heavy process.
Your data annotation process will determine the quality of your targets, which, in turn, determines the quality of your model. Carefully consider the options you have available:
Outsourcing can potentially save you time and money, but it takes away control. Using something like Mechanical Turk is likely to be inexpensive and to scale well, but your annotations may end up being noisy.
To pick the best option, consider the constraints you’re working with:
If you decide to label your data in-house, ask yourself what software you will use to record annotations. You may well need to develop that software yourself. Productive data annotation software will save you a lot of time, so it’s something worth investing in early in a project.
Machine learning models can only make sense of inputs that are similar to what they’ve seen before. As such, it’s critical that the data used for training should be representative of the production data. This concern should be the foundation of all of your data collection work.
Suppose you’re developing an app that lets users take pictures of a dish to find out its name. You train a model using pictures from an image-sharing social network that’s popular with foodies. Come deployment time, feedback from angry users starts rolling in: your app gets the answer wrong 8 times out of 10. What’s going on? Your accuracy on the test set was well over 90%! A quick look at user-uploaded data reveals that mobile picture uploads of random dishes from random restaurants taken with random smartphones look nothing like the professional-quality, well-lit, appetizing pictures you trained the model on: your training data wasn’t representative of the production data. That’s a cardinal sin—welcome to machine learning hell.
If possible, collect data directly from the environment where your model will be used. A movie review sentiment classification model should be used on new IMDb reviews, not on Yelp restaurant reviews or Twitter status updates. If you want to rate the sentiment of a tweet, start by collecting and annotating actual tweets—from a set of users similar to those you’re expecting in production. If it’s not possible to train on production data, then make sure you fully understand how your training and production data differ, and that you are actively correcting these differences.
A related phenomenon you should be aware of is concept drift. You’ll encounter concept drift in almost all real-world problems, especially those that deal with user-generated data. Concept drift occurs when the properties of the production data change over time, causing model accuracy to gradually decay. A music recommendation engine trained in the year 2013 may not be very effective today. Likewise, the IMDb dataset you worked with was collected in 2011, and a model trained on it would likely not perform as well on reviews from 2020 compared to reviews from 2012, as vocabulary, expressions, and movie genres evolve over time. Concept drift is particularly acute in adversarial contexts like credit card fraud detection, where fraud patterns change practically every day. Dealing with fast concept drift requires constant data collection, annotation, and model retraining.
Keep in mind that machine learning can only be used to memorize patterns that are present in your training data. You can only recognize what you’ve seen before. Using machine learning trained on past data to predict the future is making the assumption that the future will behave like the past. That often isn’t the case.
A particularly insidious and common case of non-representative data is sampling bias. Sampling bias occurs when your data collection process interacts with what you are trying to predict, resulting in biased measurements. A famous historical example occurred in the 1948 US presidential election. On election night, the Chicago Tribune printed the headline “DEWEY DEFEATS TRUMAN.” The next morning, Truman emerged as the winner. The editor of the Tribune had trusted the results of a phone survey—but phone users in 1948 were not a random, representative sample of the voting population. They were more likely to be richer, conservative, and to vote for Dewey, the Republican candidate.
Nowadays, every phone survey takes sampling bias into account. That doesn’t mean that sampling bias is a thing of the past in political polling—far from it. But unlike in 1948, pollsters are aware of it and take steps to correct it.
It’s bad practice to treat a dataset as a black box. Before you start training models, you should explore and visualize your data to gain insights about what makes it predictive—which will inform feature engineering—and screen for potential problems:
To control something, you need to be able to observe it. To achieve success on a project, you must first define what you mean by success. Accuracy? Precision and recall? Customer retention rate? Your metric for success will guide all of the technical choices you will make throughout the project. It should directly align with your higher-level goals, such as the business success of your customer.
For balanced classification problems, where every class is equally likely, accuracy and area under the receiver operating characteristic curve (ROC AUC) are common metrics. For class-imbalanced problems, ranking problems, or multilabel classification, you can use precision and recall, or a metric built from the counts of true positives, false positives, true negatives, and false negatives. And it isn’t uncommon to have to define your own custom metric by which to measure success. To get a sense of the diversity of machine learning success metrics and how they relate to different problem domains, it’s helpful to browse the data science competitions on Kaggle (https://kaggle.com); these competitions showcase a wide range of problems and evaluation metrics.
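For instance, here is a minimal sketch of how precision and recall are computed from those counts, assuming hypothetical vectors y_true and y_pred of 0/1 class labels (with y_pred obtained by thresholding the model’s predicted probabilities):

```r
# Hypothetical 0/1 label vectors: y_true (ground truth) and y_pred (predictions).
true_positives  <- sum(y_pred == 1 & y_true == 1)
false_positives <- sum(y_pred == 1 & y_true == 0)
false_negatives <- sum(y_pred == 0 & y_true == 1)

precision <- true_positives / (true_positives + false_positives)  # Of all flagged samples, how many were right?
recall    <- true_positives / (true_positives + false_negatives)  # Of all positive samples, how many did we catch?
```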
Once you know how you will measure your progress, you can get started with model development. Most tutorials and research projects assume that this is the only step—skipping problem definition and dataset collection, which are assumed to be already done, and skipping model deployment and maintenance, which is assumed to be handled by someone else. In fact, model development is only one step in the machine learning workflow, and if you ask us, it’s not the most difficult. The hardest things in machine learning are framing problems and collecting, annotating, and cleaning data. So cheer up: what comes next will be easy in comparison!
As you’ve learned, deep learning models typically don’t ingest raw data. Data preprocessing aims to make the raw data more amenable to neural networks. This includes vectorization, normalization, and handling missing values. Many preprocessing techniques are domain specific (for example, specific to text data or image data); we’ll cover those in the following chapters as we encounter them in practical examples. For now, we’ll review the basics that are common to all data domains.
All inputs and targets in a neural network must typically be tensors of floating-point data (or, in specific cases, tensors of integers or strings). Whatever data you need to process—sound, images, text—you must first turn into tensors, a step called data vectorization. For instance, in the two previous text classification examples from chapter 4, we started from text represented as lists of integers (standing for sequences of words), and we used multi-hot encoding to turn them into a tensor of floating-point values. In the examples of classifying digits and predicting house prices, the data already came in vectorized form, so we were able to skip this step.
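As a refresher, here is a sketch of that multi-hot encoding step, assuming sequences is a list of integer vectors (word indices) and num_words is the vocabulary size:

```r
# Multi-hot encode lists of word indices into a (samples, num_words) float matrix.
vectorize_sequences <- function(sequences, num_words) {
  results <- matrix(0, nrow = length(sequences), ncol = num_words)
  for (i in seq_along(sequences))
    results[i, sequences[[i]]] <- 1   # Set the positions of the words present in sample i to 1.
  results
}

x_train <- vectorize_sequences(sequences, num_words = 10000)
```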
In the MNIST digit-classification example from chapter 2, we started from image data encoded as integers in the 0–255 range, encoding grayscale values. Before we fed this data into our network, we had to divide by 255 so we’d end up with floating-point values in the 0–1 range. Similarly, when predicting house prices, we started from features that took a variety of ranges: some features had small floating-point values, others had fairly large integer values. Before we fed this data into our network, we had to normalize each feature independently so that it had a standard deviation of 1 and a mean of 0.
In general, it isn’t safe to feed into a neural network data that takes relatively large values (for example, multidigit integers, which are much larger than the initial values taken by the weights of a network) or data that is heterogeneous (for example, data where one feature is in the range 0–1, and another is in the range 100–200). Doing so can trigger large gradient updates that will prevent the network from converging. To make learning easier for your network, your data should have the following characteristics:
Additionally, the following stricter normalization practice is common and can help, although it isn’t always necessary (for example, we didn’t do this in the digit-classification example):
This is easy to do with R arrays:
```r
x <- scale(x)   # Assuming x is a 2D data matrix of shape (samples, features).
```

You may sometimes have missing values in your data. For instance, in the house price example, the second feature was the median age of houses in the district. What if this feature weren’t available for all samples? We’d then have missing values in the training or test data.
You could just discard the feature entirely, but you don’t necessarily have to:
Note that if you’re expecting missing categorical features in the test data, but the network was trained on data without any missing values, the network won’t have learned to ignore missing values! In this situation, you should artificially generate training samples with missing entries: copy some training samples several times, and drop some of the categorical features that you expect are likely to be missing in the test data.
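Here is a sketch of both ideas for numerical features, assuming x_train and x_test are numeric matrices in which missing entries are encoded as NA (the column index and number of copies below are purely illustrative):

```r
# Replace missing values in each feature with that feature's mean,
# computed on the training data only.
feature_means <- colMeans(x_train, na.rm = TRUE)
for (j in seq_len(ncol(x_train))) {
  x_train[is.na(x_train[, j]), j] <- feature_means[j]
  x_test[is.na(x_test[, j]), j]   <- feature_means[j]
}

# Artificially generate training samples with missing entries: copy some
# samples and blank out a feature you expect to be missing in production.
copies <- x_train[sample(nrow(x_train), 100, replace = TRUE), , drop = FALSE]
copies[, 2] <- feature_means[2]   # Pretend feature 2 (e.g., median house age) was missing.
x_train <- rbind(x_train, copies)
```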
As you learned in the previous chapter, the purpose of a model is to achieve generalization, and every modeling decision you will make throughout the model development process will be guided by validation metrics that seek to measure generalization performance. The goal of your validation protocol is to accurately estimate what your success metric of choice (such as accuracy) will be on actual production data. The reliability of that process is critical to building a useful model.
In chapter 5, we reviewed three common evaluation protocols:
Pick any one of these. In most cases, the first will work well enough. As you’ve learned, always be mindful of the representativeness of your validation set(s), and be careful not to have redundant samples between your training set and your validation set(s).
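As a reminder, the simplest of these protocols, a plain holdout split, takes only a few lines; the following sketch assumes x is a matrix of samples and y the corresponding vector of targets:

```r
# Shuffle the data, then hold out 20% of it for validation.
set.seed(1)                            # For reproducibility of the split.
indices <- sample(nrow(x))
num_val <- round(0.2 * nrow(x))

val_indices   <- indices[seq_len(num_val)]
train_indices <- indices[-seq_len(num_val)]

x_val   <- x[val_indices, ];   y_val   <- y[val_indices]
x_train <- x[train_indices, ]; y_train <- y[train_indices]
```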
As you start working on the model itself, your initial goal is to achieve statistical power, as you saw in chapter 5: that is, to develop a small model that is capable of beating a simple baseline. At this stage, these are the three most important things you should focus on:
It’s often not possible to directly optimize for the metric that measures success on a problem. Sometimes there is no easy way to turn a metric into a loss function; loss functions, after all, need to be computable given only a mini-batch of data (ideally, a loss function should be computable for as little as a single data point) and must be differentiable (otherwise, we can’t use backpropagation to train our network). For instance, the widely used classification metric ROC AUC can’t be directly optimized. Hence, in classification tasks, it’s common to optimize for a proxy metric of ROC AUC, such as cross-entropy. In general, we can hope that the lower the cross-entropy gets, the higher the ROC AUC will be.
The following table can help you choose a last-layer activation, a loss function, and metrics for a few common problem types.
| Task | Last-layer activation | Loss function | Metrics |
|---|---|---|---|
| Binary classification | Sigmoid | Binary cross-entropy | Binary accuracy, ROC AUC |
| Multiclass, single-label classification | Softmax | Categorical cross-entropy | Categorical accuracy, Top-k categorical accuracy, ROC AUC |
| Multiclass, multilabel classification | Sigmoid | Binary cross-entropy | Binary accuracy, ROC AUC |
| Regression | None | Mean squared error | Mean absolute error |
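To make the first row of the table concrete, here is a sketch of a small binary classifier in keras3, where num_features stands in for your input dimensionality; it optimizes cross-entropy as the proxy loss while monitoring accuracy and ROC AUC:

```r
library(keras3)

model <- keras_model_sequential(input_shape = c(num_features)) |>
  layer_dense(32, activation = "relu") |>
  layer_dense(1, activation = "sigmoid")      # Last-layer activation: sigmoid.

model |> compile(
  optimizer = "rmsprop",
  loss = "binary_crossentropy",               # The differentiable proxy we optimize.
  metrics = list("accuracy", metric_auc())    # The metrics we actually monitor.
)
```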
For most problems, there are existing templates you can start from. You’re not the first person to try to build a spam detector, a music recommendation engine, or an image classifier. Be sure to research prior art to identify the feature engineering techniques and model architectures that are most likely to perform well on your task.
Note that it’s not always possible to achieve statistical power. If you can’t beat a simple baseline after trying multiple reasonable architectures, it may be that the answer to the question you’re asking isn’t present in the input data. Remember that you’re making two hypotheses:
These hypotheses may be false, in which case you must go back to the drawing board.
Once you’ve obtained a model that has statistical power, the question becomes, is your model sufficiently powerful? Does it have enough layers and parameters to properly model the problem at hand? For instance, a logistic regression model has statistical power on MNIST, but it wouldn’t be sufficient to solve the problem well. Remember that the universal tension in machine learning is between optimization and generalization; the ideal model is one that stands right at the border between underfitting and overfitting; between undercapacity and overcapacity. To figure out where this border lies, first we must cross it.
To figure out how big a model you’ll need, you must develop a model that overfits. This is fairly easy, as you learned in chapter 5:
Always monitor the training loss and validation loss, as well as the training and validation values for any metrics you care about. When you see that the model’s performance on the validation data begins to degrade, you’ve achieved overfitting.
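For instance, continuing the binary classification sketch from earlier (the model, training arrays, and validation arrays are the hypothetical ones defined previously), you could deliberately train for many epochs and inspect the curves:

```r
history <- model |> fit(
  x_train, y_train,
  epochs = 100,                               # Train long enough to cross into overfitting.
  batch_size = 128,
  validation_data = list(x_val, y_val)
)

plot(history)   # The point where validation loss turns upward marks the onset of overfitting.
```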
Once you’ve achieved statistical power and you’re able to overfit, you know you’re on the right path. At this point, your goal becomes to maximize generalization performance.
This phase will take the most time: you’ll repeatedly modify your model, train it, evaluate on your validation data (not the test data, at this point), modify it again, and repeat, until the model is as good as it can get. Here are some things you should try:
It’s possible to automate a large chunk of this work by using automated hyperparameter tuning software, such as KerasTuner. We’ll cover this in chapter 18.
Be mindful of the following: every time you use feedback from your validation process to tune your model, you leak information about the validation process into the model. Repeated just a few times, this is innocuous; however, done systematically over many iterations, it will eventually cause your model to overfit to the validation process (even though no model is directly trained on any of the validation data). This makes the evaluation process less reliable.
Once you’ve developed a satisfactory model configuration, you can train your final production model on all the available data (training and validation) and evaluate it one last time on the test set. If it turns out that performance on the test set is significantly worse than the performance measured on the validation data, this may mean either that your validation procedure wasn’t reliable after all or that you began overfitting to the validation data while tuning the parameters of the model. In this case, you may want to switch to a more reliable evaluation protocol (such as iterated K-fold validation).
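In code, this last step might look like the following sketch, where build_model() and best_epochs are hypothetical placeholders for your chosen architecture and the epoch count selected during validation:

```r
# Retrain the frozen configuration on training + validation data,
# then evaluate exactly once on the held-out test set.
x_full <- rbind(x_train, x_val)
y_full <- c(y_train, y_val)

final_model <- build_model()   # Hypothetical helper that rebuilds the chosen architecture.
final_model |> fit(x_full, y_full, epochs = best_epochs, batch_size = 128)

final_model |> evaluate(x_test, y_test)
```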
After your model has successfully cleared its final evaluation on the test set, it’s ready to be deployed and to begin its productive life.
Success and customer trust are about consistently meeting or exceeding people’s expectations; the actual system you deliver is only half of that picture. The other half is setting appropriate expectations before launch.
The expectations of nonspecialists toward AI systems are often unrealistic. For example, they might expect that the system “understands” its task and is capable of exercising human-like common sense in the context of the task. To address this, you should consider showing some examples of the failure modes of your model (for instance, show what incorrectly classified samples look like, especially those for which the misclassification seems surprising).
They might also expect human-level performance, especially for processes that were previously handled by people. Most machine learning models, because they are (imperfectly) trained to approximate human-generated labels, get nowhere close. You should clearly convey model performance expectations. Avoid using abstract statements like “The model has 98% accuracy” (which most people mentally round up to 100%), and instead talk about, for instance, false-negative rates and false-positive rates. You could say, “With these settings, the fraud detection model would have a 5% false-negative rate and a 2.5% false-positive rate. Every day, an average of 200 valid transactions would be flagged as fraudulent and sent for manual review, and an average of 14 fraudulent transactions would be missed. An average of 266 fraudulent transactions would be correctly caught.” Clearly relate the model’s performance metrics to business goals.
You should also be sure to discuss with stakeholders the choice of key launch parameters: for instance, the probability threshold at which a transaction should be flagged (different thresholds will produce different false-negative and false-positive rates). Such decisions involve tradeoffs that can only be handled with a deep understanding of the business context.
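A small sketch can make such a threshold discussion concrete; here, val_probs and val_labels are hypothetical predicted fraud probabilities and 0/1 ground-truth labels on validation data:

```r
# Show how the flagging threshold trades false positives against false negatives.
for (threshold in c(0.3, 0.5, 0.7)) {
  flagged <- val_probs >= threshold
  fp_rate <- sum(flagged & val_labels == 0) / sum(val_labels == 0)
  fn_rate <- sum(!flagged & val_labels == 1) / sum(val_labels == 1)
  cat(sprintf("threshold %.1f: false-positive rate %.1f%%, false-negative rate %.1f%%\n",
              threshold, 100 * fp_rate, 100 * fn_rate))
}
```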
A machine learning project doesn’t end when you arrive at a script that can save a trained model. You rarely put into production the exact same model object that you manipulated during training.
First, you may want to export your model to something other than R or Python:
Second, because your production model will only be used to output predictions (a phase called inference), rather than for training, you have room to perform various optimizations that can make the model faster and reduce its memory footprint. Let’s take a quick look at the available model deployment options.
Perhaps the easiest way to turn a model into a product is to serve it online via a REST API. There are a number of libraries out there for making this happen. Keras supports two of the most popular approaches out of the box: TensorFlow Serving and ONNX (Open Neural Network Exchange). Both libraries operate by lifting all model weights and a computation graph outside of the R program so you can serve it from a number of different environments (for example, a C++ server). If this sounds a lot like the compilation mechanism discussed in chapter 3, you are spot on. TensorFlow Serving is essentially a library for serving tf_function() computation graphs with a specific set of saved weights.
Keras allows access to both TensorFlow Serving and ONNX via an easy-to-use export_savedmodel() method for Keras models. Here’s how this works for TensorFlow Serving:
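The flow might look like the following sketch: export a SavedModel, then point a TensorFlow Serving instance at the exported directory (the Docker invocation in the comments is one common option; the details depend on your serving setup):

```r
model |> export_savedmodel("path/to/location")

# Then serve the exported directory, for example with the official Docker image:
#   docker run -p 8501:8501 \
#     -v /path/to/location:/models/my_model -e MODEL_NAME=my_model \
#     tensorflow/serving
```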
A similar flow exists for ONNX:
```r
model |> export_savedmodel("path/to/location", format = "onnx")

reticulate::py_require("onnxruntime")
onnxruntime <- reticulate::import("onnxruntime")
ort_session <- onnxruntime$InferenceSession("path/to/location")
predictions <- ort_session$run(NULL, input_data)
```

Posit Connect supports directly uploading an exported Keras model for serving:

```r
export_savedmodel(model, "model-for-serving")
rsconnect::deployTFModel("model-for-serving")
```

You should use this deployment setup in the following cases:
For instance, the image search engine project, the music recommender system, the credit card fraud detection project, and the satellite imagery project are all good fits for serving via a REST API.
An important question when deploying a model as a REST API is whether to host the code on your own or use a fully managed third-party cloud service. For instance, Cloud AI Platform, a Google product, lets you simply upload your TensorFlow model to Google Cloud Storage (GCS) and gives you an API endpoint to query it. It takes care of many practical details such as batching predictions, load balancing, and scaling.
Sometimes you may need your model to live on the same device that runs the application that uses it—maybe a smartphone, an embedded ARM CPU on a robot, or a microcontroller on a tiny device. For instance, perhaps you’ve seen a camera capable of automatically detecting people and faces in the scenes you pointed it at: that was probably a small deep learning model running directly on the camera.
You should use this setup in these cases:
For instance, our spam detection model will need to run on the end user’s smartphone as part of the chat app, because messages are end-to-end encrypted and thus cannot be read by a remotely hosted model. Likewise, the bad-cookie-detection model has strict latency constraints and will need to run at the factory. Fortunately, in this case we don’t have any power or space constraints, so we can run the model on a GPU.
To deploy a Keras model on a smartphone or embedded device, you can again use the export_savedmodel() method to create a TensorFlow or ONNX save of your model, including the computation graph. TensorFlow Lite (https://www.tensorflow.org/lite) is a framework for efficient on-device deep learning inference that runs on Android and iOS smartphones, as well as ARM CPUs, Raspberry Pi, or certain microcontrollers. It uses the same TensorFlow save model format as TensorFlow Serving. The ONNX runtime can also run on mobile devices.
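One possible conversion route, sketched under the assumption that the tensorflow Python package is available via reticulate, is to run the TensorFlow Lite converter on the exported SavedModel:

```r
library(reticulate)
tf <- import("tensorflow")

converter <- tf$lite$TFLiteConverter$from_saved_model("path/to/location")
tflite_model <- converter$convert()       # Returned bytes arrive in R as a raw vector.

writeBin(tflite_model, "model.tflite")    # Ship this file with your mobile or embedded app.
```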
Deep learning is often used in browser-based or desktop-based JavaScript applications. Although it is usually possible to have the application query a remote model via a REST API, there can be key advantages in having the model run directly in the browser, on the user’s computer (utilizing GPU resources if available).
Use this setup in these cases:
Of course, you should only go with this option if your model is small enough that it won’t hog the CPU, GPU, or RAM of your user’s laptop or smartphone. In addition, because the entire model will be downloaded to the user’s device, you should make sure that nothing about the model needs to stay confidential. Be mindful that given a trained deep learning model, it is usually possible to recover some information about the training data: better not to make your trained model public if it was trained on sensitive data.
To deploy a model in JavaScript, the TensorFlow ecosystem includes TensorFlow.js (https://www.tensorflow.org/js), and ONNX supports a native JavaScript runtime. TensorFlow.js even implements almost all of the Keras API (it was originally developed under the working name WebKeras) as well as many lower-level TensorFlow APIs. You can easily import a saved Keras model into TensorFlow.js to query it as part of your browser-based JavaScript app or your desktop Electron app.
Optimizing your model for inference is especially important when deploying in an environment with strict constraints on available power and memory (smartphones and embedded devices) or for applications with low latency requirements. You should always seek to optimize your model before importing it into TensorFlow.js or exporting it to TensorFlow Lite.
There are two popular optimization techniques you can apply:
Weight quantization: Deep learning models are trained with single-precision floating-point (float32) weights. However, it’s possible to quantize weights to 8-bit signed integers (int8) to get an inference-only model that’s four times smaller while staying close to the accuracy of the original model. Keras models come with a built-in quantize_weights() API that can help with this. Simply call quantize_weights(model, "int8") to compress each weight in your model to a single byte.

You’ve exported an inference model, you’ve integrated it into your application, and you’ve done a dry run on production data—and the model behaved exactly as you expected. You’ve written unit tests as well as logging and status monitoring code—perfect. Now, it’s time to press the big red button and deploy to production.
Even this is not the end. Once you’ve deployed a model, you need to keep monitoring its behavior, its performance on new data, its interaction with the rest of the application, and its eventual effect on business metrics:
No model lasts forever. You’ve already learned about concept drift: over time, the characteristics of your production data will change, gradually degrading the performance and relevance of your model. The lifespan of your music recommender system will be counted in weeks. For the credit card fraud detection systems, it would be days. A couple of years is the best case for the image search engine.
As soon as your model has launched, you should be getting ready to train the next generation that will replace it:
This concludes the universal workflow of machine learning—and it’s a lot of things to keep in mind. It takes time and experience to become an expert, but don’t worry, you’re already a lot wiser than you were a few chapters ago. You are now familiar with the big picture: the entire spectrum of what machine learning projects entail. Although most of this book will focus on model development, you’re now aware that it’s only one part of the entire workflow. Always keep the big picture in mind!