Avoiding Common Pitfalls

All models are wrong, but some are useful.
— George E. P. Box

The above quote is also nicely exemplified by this xkcd comic:

[Image: xkcd comic]

A supervised learning model tries to infer the relationship between some inputs and outputs from the given exemplary data points. What kind of relation is found is largely determined by the chosen model type and its internal optimization algorithm. However, there is a lot we can (and should) do to make sure that what the algorithm comes up with is not blatantly wrong.

What do we want?

A model that …

  • … makes accurate predictions

  • … for new data points

  • … for the right reasons

  • … even when the world keeps on changing.

What can go wrong?
  • Evaluating the model with an inappropriate evaluation metric (e.g., accuracy instead of balanced accuracy for a classification problem with an unequal class distribution), thereby not noticing the subpar performance of a model (e.g., compared to a simple baseline).

  • Using a model that cannot capture the ‘input → output’ relationship (due to underfitting) and thus does not generate useful predictions.

  • Using a model that overfits the training data and therefore does not generalize to new data points.

  • Using a model that exploits spurious correlations.

  • Using a model that discriminates.

  • Not monitoring and retraining the model regularly on new data.

Below you’ll find a quick summary of what you can do to avoid these pitfalls; we’ll discuss most of these points in more detail in the following sections.

Before training a model
  • Select the right inputs: ask a domain expert which variables could have a causal influence on the output; possibly compute additional, more informative features from the original measurements (→ feature engineering).

  • Sanity check: Does the dataset contain samples with the same inputs but different outputs? ⇒ Some important features might be missing or the targets are very noisy, e.g., due to inconsistent annotations — fix this first!

  • Try a simple model (linear model or decision tree) — this can serve as a reasonable baseline when experimenting with more complex models (this step and the sanity check above are illustrated in the sketch after this list).

  • Think about the structure of the problem and what type of model might be appropriate to learn the presumed ‘input → output’ relationship. For example, if the problem is clearly nonlinear, the chosen model type also needs to be complex enough to at least in principle be able to pick up on this relation (i.e., such that the model does not underfit, see below). A lot of domain knowledge can also be put into the design of neural network architectures.

  • Make sure the data satisfies the model’s assumptions ⇒ for pretty much all models except decision trees and models based on decision trees, like random forests, the data should be approximately normally distributed.

  • Make sure you’ve set aside a representative test set to evaluate the final model and possibly a validation set for model selection and hyperparameter tuning.
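A minimal sketch of the sanity check and baseline steps above, assuming the data lives in a pandas DataFrame with numeric feature columns and a "target" column (all names here are placeholder assumptions):

```python
import pandas as pd
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# hypothetical dataset: numeric feature columns + a "target" column
df = pd.read_csv("data.csv")
feature_cols = [c for c in df.columns if c != "target"]

# sanity check: identical inputs that map to different outputs
conflicts = df.groupby(feature_cols)["target"].nunique()
print(f"{(conflicts > 1).sum()} input combinations with conflicting labels")

# set aside a test set and compare a trivial and a simple baseline
X_train, X_test, y_train, y_test = train_test_split(
    df[feature_cols], df["target"], test_size=0.2, random_state=42
)
dummy = DummyClassifier(strategy="most_frequent").fit(X_train, y_train)
linear = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print("majority-class baseline:", dummy.score(X_test, y_test))
print("linear model:           ", linear.score(X_test, y_test))
```

Any more complex model should clearly beat the majority-class and linear baselines before its added complexity is worth it.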

After the model was trained
  • Evaluate the model with a meaningful evaluation metric, especially when the classes in the dataset are not distributed evenly (→ balanced accuracy; see the sketch after this list).

  • Check that the model can interpolate, i.e., that it generalizes to unseen data points from the same distribution as the training set and does not over- or underfit. Please note that this does not ensure that it can also extrapolate, i.e., that it has learned the true causal relation between inputs and outputs and can generate correct predictions for data points outside of the training domain!

  • Carefully analyze the model’s prediction errors to check for systematic errors, which can indicate that the data violates your initial assumptions. For example, in a classification task the performance for all classes should be approximately the same, while in a regression task the residuals should be independent.

  • Verify that the model does not discriminate. Due to the large quantities of data used to train ML models, it is not always possible to ensure that the training data does not contain any systematic biases (e.g., ethnicity/gender stereotypes) that a model might pick up on, but it is important to test the model on a controlled test set and individual data slices to catch any discrimination before the model is deployed in production.

  • Interpret the model and explain its predictions: Does it use the features you or a domain expert expected it to use or does it make predictions based on any spurious correlations?

  • If necessary, use model editing or assertions to fix incorrect model predictions. For example, you can manually alter the rules learned by a decision tree or implement additional business rules that override model predictions or act as sanity checks (e.g., a predicted age should never be negative).
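As a sketch of the evaluation step, reusing the hypothetical linear model and test split from the earlier sketch: plain accuracy is compared to balanced accuracy (the average recall over all classes), and a per-class report helps to spot systematic errors for individual classes.

```python
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
                             classification_report)

y_pred = linear.predict(X_test)

# plain accuracy can look fine even if a minority class is almost always missed
print("accuracy:         ", accuracy_score(y_test, y_pred))
# balanced accuracy averages the recall of each class -> robust to class imbalance
print("balanced accuracy:", balanced_accuracy_score(y_test, y_pred))
# per-class precision/recall to reveal systematic errors on individual classes
print(classification_report(y_test, y_pred))
```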

Please note that these steps represent an iterative workflow, i.e., after training some model and analyzing its performance one often goes back to the beginning and, e.g., selects different features or tries a more complex model to improve the performance.

And after the model was deployed…​
ML fails silently! That is, even if all predictions are wrong, the program does not simply crash with an error message.
→ Need constant monitoring to detect changes that lead to a deteriorating performance!

One of the biggest problems in practice: Data and Concept Drifts:
Model performance quickly decays when the distribution of the data used for training, \(P_{train}(y, X)\), differs from the distribution of the data the model encounters in production, \(P_{prod}(y, X)\). Such a discrepancy can be due to

  • Data drift: distribution of input features \(X\) changes, i.e., \(P_{train}(X) \neq P_{prod}(X)\)

  • Concept drift: input/output relationship \(X \to y\) changes, i.e., \(P_{train}(y|X) \neq P_{prod}(y|X)\)

Example: From the production settings, including the size of a produced part (\(X\)), we want to predict whether the part is scrap or okay (\(y\)):

  • Data drift: The company used to manufacture only small parts, now they also produce larger parts.

  • Concept drift: The company used to produce 10% scrap parts, but after some maintenance was done on the machine, the same production settings now result in only 5% scrap.

Possible reasons for data or concept drifts:

  • New users behave differently: For example, the product is now used by a younger generation or the business recently expanded to a different country.

  • Trends and seasonality: For example, certain products are bought primarily in a specific season (e.g., winter clothes), or new styles result in the same customers now buying different products because they are “in”.

  • The process structure / setup changed: For example, in a new version of a device, a sensor has moved to a different location, but still logs values under the same variable name. Or, due to a software update, values that were previously logged as integers are suddenly converted to strings (although this should actually raise an error).

  • Feedback loop, where the presence of an ML model alters user behavior: For example, spammers change their strategy to evade the spam filter.

  • You used different preprocessing steps on the training and production data, for example, because you did not properly document all of the transformations that were applied to the initial dataset.

These changes can either be gradual (e.g., languages change gradually as new words are coined; a camera lens gets covered with dust over time), or they can come as a sudden shock (e.g., someone cleans the camera lens; when the COVID-19 pandemic hit, suddenly a lot of people switched to online shopping, which tripped up the credit card fraud detection systems).

Mitigation Strategies:
The best way to counteract data and concept drifts is to frequently retrain the model on new data. This can either happen on a schedule (e.g., every weekend, depending on how quickly the data changes) or when your monitoring system (see below) raises an alert because it detected drifts in the inputs or a deteriorating model performance.
Data drifts can be seen as an opportunity to extend the training set with more diverse samples. Concept drifts, on the other hand, entail the need to remove samples from the original training set that no longer conform to the new input/output relationship before retraining the model. While traditional ML models typically need to be retrained from scratch, neural network models can also be fine-tuned on newly collected data; however, this is only useful when faced with minor data drifts, not drastic concept drifts.
Additionally, known changes can also be included in the model as additional features (e.g., seasons).
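As an illustration, a minimal retraining sketch, assuming the continuously collected data is stored in a pandas DataFrame with a timestamp and a target column (the column names, window size, and model type are placeholder assumptions):

```python
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

def retrain(history: pd.DataFrame, feature_cols: list, window: str = "90D"):
    """Retrain the model from scratch on a rolling window of recent data.

    Dropping samples older than `window` is a crude way to discard data that
    may no longer reflect the current input/output relationship (concept
    drift); for pure data drift you would instead keep the old samples and
    simply add the newly collected ones to the training set.
    """
    cutoff = history["timestamp"].max() - pd.Timedelta(window)
    recent = history[history["timestamp"] >= cutoff]
    model = RandomForestClassifier(random_state=0)
    model.fit(recent[feature_cols], recent["target"])
    return model
```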

Possible components of a monitoring system:

  • Use statistical tests or divergence measures to detect shifted distributions of individual features compared to the training data (see the sketch further below):

    • Kullback-Leibler divergence

    • Jensen-Shannon divergence

    • Kolmogorov-Smirnov (K-S) test

  • Use novelty detection or clustering to identify data points that are different from the training samples. Even if a sample’s individual feature values are still in a normal range, this particular combination of feature values can be far from what the model encountered during training.

  • Check if there is a difference between the predicted and true (training) label frequencies. For example, if in reality usually about 10% of our produced products are faulty, but the model suddenly predicts that 50% of the products are faulty, then something is probably off.

  • Check whether the confidence scores of the model predictions (i.e., the probability for a class, not the predicted class label) get lower, which indicates that new samples are closer to the model’s decision boundary than the training samples.

  • Check the error metrics of the model on new data (only possible if you continuously collect new labeled data).

  • After retraining the model on new data, check if the feature importances changed, which indicates that it might be time to select different features for the prediction.

These checks can be combined with a sliding window approach, for example, every hour the data collected in the last 48 hours is compared to the training data. If any of the monitoring values exceed some predefined threshold, the system triggers an alert.
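For example, the per-feature part of such a check could look like the following sketch, which uses the two-sample Kolmogorov-Smirnov test from SciPy (train_df, last_48h_df, feature_cols, and send_alert are placeholders for your own data and alerting mechanism):

```python
from scipy.stats import ks_2samp

def check_feature_drift(train_df, recent_df, feature_cols, alpha=0.01):
    """Compare a window of recent production data against the training data.

    Runs a two-sample Kolmogorov-Smirnov test per feature and returns the
    features whose distribution has shifted significantly (p-value < alpha).
    """
    drifted = []
    for col in feature_cols:
        stat, p_value = ks_2samp(train_df[col], recent_df[col])
        if p_value < alpha:
            drifted.append((col, stat, p_value))
    return drifted

# e.g., called every hour on the data collected in the last 48 hours:
# drifted = check_feature_drift(train_df, last_48h_df, feature_cols)
# if drifted:
#     send_alert(f"data drift detected in: {[name for name, _, _ in drifted]}")
```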

Additionally:

  • Validate the input data schema, i.e., check that data types and value ranges (incl. missing values / NaNs) match those encountered in the training data (see the sketch below).

  • Log known external events (!!), e.g., maintenance on a machine.
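A sketch of such a schema check, with a hypothetical hard-coded reference schema (in practice the expected data types and value ranges would be derived from the training data):

```python
import numpy as np
import pandas as pd

# hypothetical reference schema; in practice derived from the training data
SCHEMA = {
    "temperature": {"dtype": np.float64, "min": -20.0, "max": 150.0, "allow_nan": False},
    "machine_id": {"dtype": object, "allowed": {"M1", "M2", "M3"}},
}

def validate_schema(df: pd.DataFrame) -> list:
    """Return a list of schema violations found in a new batch of data."""
    problems = []
    for col, spec in SCHEMA.items():
        if col not in df.columns:
            problems.append(f"missing column: {col}")
            continue
        if df[col].dtype != spec["dtype"]:
            problems.append(f"{col}: dtype {df[col].dtype}, expected {spec['dtype']}")
        if not spec.get("allow_nan", True) and df[col].isna().any():
            problems.append(f"{col}: contains missing values")
        if "min" in spec and (df[col] < spec["min"]).any():
            problems.append(f"{col}: values below {spec['min']}")
        if "max" in spec and (df[col] > spec["max"]).any():
            problems.append(f"{col}: values above {spec['max']}")
        if "allowed" in spec and not set(df[col].dropna()).issubset(spec["allowed"]):
            problems.append(f"{col}: unexpected category values")
    return problems
```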

The "ML fails silently" part also applies to bugs in your code: especially when you’re just starting out with ML, it often happens that your results seem fine (maybe just a little too good), yet a subtle bug somewhere doesn’t crash your program but silently computes something slightly wrong. These issues can be very hard to notice, so always triple-check your code and, if you can, write unit tests for individual functions to make sure they do what you expect them to do.
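For example, a small pytest-style unit test for a hypothetical preprocessing function could look like this:

```python
import numpy as np

def standardize(x: np.ndarray) -> np.ndarray:
    """Hypothetical preprocessing step: scale to zero mean and unit variance."""
    return (x - x.mean()) / x.std()

def test_standardize():
    x = np.array([1.0, 2.0, 3.0, 4.0])
    z = standardize(x)
    assert np.isclose(z.mean(), 0.0)
    assert np.isclose(z.std(), 1.0)
```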
You might also want to have a look at Google’s rules of machine learning (the first one being: "Don’t be afraid to launch a product without machine learning.")