Avoiding Common Pitfalls
All models are wrong, but some are useful. (George Box)
The above quote is also nicely exemplified by this xkcd comic:

A supervised learning model tries to infer the relationship between some inputs and outputs from the given exemplary data points. What kind of relation is found is largely determined by the chosen model type and its internal optimization algorithm. However, there is still a lot we can (and should) do to make sure that what the algorithm comes up with is not blatantly wrong.
What do we want?
A model that …
- … makes accurate predictions
- … for new data points
- … for the right reasons
- … even when the world keeps on changing.
What can go wrong?
- Evaluating the model with an inappropriate evaluation metric (e.g., accuracy instead of balanced accuracy for a classification problem with an unequal class distribution), thereby not noticing the subpar performance of a model (e.g., compared to a simple baseline; see the sketch after this list).
- Using a model that cannot capture the ‘input → output’ relationship (due to underfitting) and therefore does not generate useful predictions.
- Using a model that overfits on the training data and therefore does not generalize to new data points.
- Using a model that abuses spurious correlations.
- Using a model that discriminates.
- Not monitoring and retraining the model regularly on new data.
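To make the first pitfall more concrete, here is a minimal sketch of how plain accuracy can hide a useless model while balanced accuracy exposes it; the class distribution and the trivial baseline are made up for illustration:

```python
# Minimal sketch: accuracy vs. balanced accuracy on an imbalanced dataset.
# The labels and the "always predict the majority class" baseline are made up.
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# hypothetical test set: 90% of the samples are ok (0), 10% are faulty (1)
y_true = np.array([0] * 90 + [1] * 10)
# trivial baseline that always predicts "ok" and never catches a faulty sample
y_pred = np.zeros_like(y_true)

print(accuracy_score(y_true, y_pred))           # 0.9 -> looks deceptively good
print(balanced_accuracy_score(y_true, y_pred))  # 0.5 -> no better than random guessing
```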
Below you’ll find a quick summary of what you can do to avoid these pitfalls; we’ll discuss most of these points in more detail in the following sections.
And after the model was deployed…
ML fails silently! I.e., even if all predictions are wrong, the program does not simply crash with an error message.
→ We need constant monitoring to detect changes that lead to deteriorating performance!
One of the biggest problems in practice: Data and Concept Drifts:
The model performance quickly decays when the distribution of the data used for training, \(P_{train}(y, X)\), is different from the data the model encounters in production, \(P_{prod}(y, X)\). Such a discrepancy can be due to:
- Data drift: the distribution of the input features \(X\) changes, i.e., \(P_{train}(X) \neq P_{prod}(X)\)
- Concept drift: the input/output relationship \(X \to y\) changes, i.e., \(P_{train}(y|X) \neq P_{prod}(y|X)\)
Example: From the production settings, incl. the size of a produced part (\(X\)), we want to predict whether the part is scrap or okay (\(y\)):
- Data drift: The company used to manufacture only small parts; now they also produce larger parts.
- Concept drift: The company used to produce 10% scrap parts, but after some maintenance was done on the machine, the same production settings now result in only 5% scrap.
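To make the difference a bit more tangible, the toy snippet below simulates this scrap example; apart from the 10% vs. 5% scrap rates mentioned above, all numbers are made up:

```python
# Toy simulation of the scrap example to illustrate data vs. concept drift.
# Apart from the scrap rates (10% -> 5%), all numbers are invented for illustration.
import numpy as np

rng = np.random.default_rng(42)
n = 1000

# training data: only small parts (feature X), 10% of them are scrap (label y)
size_train = rng.normal(loc=10.0, scale=1.0, size=n)   # part size in mm
scrap_train = rng.random(n) < 0.10

# data drift: P(X) changes -- the company now also produces larger parts
size_prod = rng.normal(loc=25.0, scale=3.0, size=n)

# concept drift: P(y|X) changes -- after maintenance, the same settings
# (same distribution of X) now yield only 5% scrap
scrap_prod = rng.random(n) < 0.05
```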
Possible reasons for data or concept drifts:
- New users behave differently: For example, the product is now used by a younger generation, or the business recently expanded to a different country.
- Trends and seasonality: For example, certain products are bought primarily in a specific season (e.g., winter clothes), or new styles result in the same customers now buying different products because they are “in”.
- The process structure / setup changed: For example, in a new version of a device, a sensor has moved to a different location but still logs values under the same variable name. Or, due to a software update, values that were previously logged as integers are suddenly converted to strings (although this should actually raise an error).
- Feedback loops, where the presence of an ML model alters user behavior: For example, spammers change their strategy to evade the spam filter.
- You used different preprocessing steps on the training and production data, for example, because you did not properly document all of the transformations that were applied to the initial dataset.
These changes can either be gradual (e.g., languages change gradually as new words are coined; a camera lens gets covered with dust over time), or they can come as a sudden shock (e.g., someone cleans the camera lens; when the COVID-19 pandemic hit, suddenly a lot of people switched to online shopping, which tripped up the credit card fraud detection systems).
Mitigation Strategies:
The best way to counteract data and concept drifts is to frequently retrain the model on new data. This can either happen on a schedule (e.g., every weekend, depending on how quickly the data changes) or when your monitoring system (see below) raises an alert because it detected drifts in the inputs or a deteriorating model performance.
Data drifts can be seen as an opportunity to extend the training set with more diverse samples. Concept drifts, on the other hand, entail the need to remove samples from the original training set that do not conform to the novel input/output relation before retraining the model. While traditional ML models typically need to be retrained from scratch, neural network models can also be fine-tuned on newly collected data; however, this is only useful when faced with minor data drifts, not drastic concept drifts.
Additionally, known changes can also be included in the model as additional features (e.g., seasons).
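For example, a known seasonal pattern could be made explicit with cyclical month features; here is a minimal sketch, assuming a pandas DataFrame with a hypothetical "timestamp" column:

```python
# Minimal sketch: encoding a known seasonal effect as additional input features.
# The DataFrame and its "timestamp" column are hypothetical.
import numpy as np
import pandas as pd

df = pd.DataFrame({"timestamp": pd.date_range("2023-01-01", periods=365, freq="D")})

# encode the month as a cyclical feature so that December and January end up close together
month = df["timestamp"].dt.month
df["month_sin"] = np.sin(2 * np.pi * month / 12)
df["month_cos"] = np.cos(2 * np.pi * month / 12)
```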
Possible components of a monitoring system:
- Use statistical tests to detect skewed distributions of individual features (see the monitoring sketch below):
  - Kullback-Leibler divergence
  - Jensen-Shannon divergence
  - Kolmogorov-Smirnov (K-S) test
- Use novelty detection or clustering to identify data points that are different from the training samples. Even if a sample’s individual feature values are still in a normal range, this particular combination of feature values can be far from what the model encountered during training.
- Check if there is a difference between the predicted and true (training) label frequencies. For example, if in reality usually about 10% of our produced products are faulty, but the model suddenly predicts that 50% of the products are faulty, then something is probably off.
- Check whether the confidence scores of the model predictions (i.e., the probability for a class, not the predicted class label) get lower, which indicates that new samples are closer to the model’s decision boundary than the training samples.
- Check the error metrics of the model on new data (only possible if you continuously collect new labeled data).
- After retraining the model on new data, check if the feature importances changed, which indicates that it might be time to select different features for the prediction.
These checks can be combined with a sliding window approach: for example, every hour, the data collected in the last 48 hours is compared to the training data. If any of the monitoring values exceed some predefined threshold, the system triggers an alert.
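Putting a few of these checks together, a simplified monitoring sketch could look like the following; the thresholds, variable names (X_train, X_recent, …), and the alerting function are all hypothetical:

```python
# Simplified sketch of a drift-monitoring check on a sliding window of production data.
# Thresholds and names (X_train, X_recent, send_alert, ...) are hypothetical.
import numpy as np
from scipy.stats import ks_2samp


def check_feature_drift(X_train, X_recent, p_threshold=0.01):
    """Two-sample Kolmogorov-Smirnov test per (numerical) feature."""
    alerts = []
    for col in X_train.columns:
        res = ks_2samp(X_train[col], X_recent[col])
        if res.pvalue < p_threshold:
            alerts.append((col, res.statistic, res.pvalue))
    return alerts


def check_label_frequencies(y_train, y_pred_recent, max_diff=0.2):
    """Compare class frequencies in the training labels to the recent predictions."""
    train_freq = np.bincount(y_train) / len(y_train)
    pred_freq = np.bincount(y_pred_recent, minlength=len(train_freq)) / len(y_pred_recent)
    return np.abs(train_freq - pred_freq).max() > max_diff


# e.g., run every hour on the data collected in the last 48 hours:
# if check_feature_drift(X_train, X_last_48h) or check_label_frequencies(y_train, y_pred_last_48h):
#     send_alert(...)  # hypothetical alerting function
```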
Additionally:
- Validate the input data schema, i.e., check that data types and value ranges (incl. missing values / NaNs) match those encountered in the training data (see the schema-check sketch below).
- Log known external events (!!), e.g., maintenance on a machine.
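A basic schema check does not require a dedicated tool; here is a minimal sketch with pandas, where the expected columns, dtypes, and value ranges are hypothetical and would in practice be derived from the training data:

```python
# Minimal sketch: validating the input data schema before making predictions.
# The expected columns, dtypes, and value ranges are hypothetical examples.
import pandas as pd

EXPECTED_DTYPES = {"temperature": "float64", "pressure": "float64", "machine_id": "int64"}
VALUE_RANGES = {"temperature": (0.0, 300.0), "pressure": (0.5, 10.0)}


def validate_schema(df: pd.DataFrame) -> list[str]:
    errors = []
    for col, dtype in EXPECTED_DTYPES.items():
        if col not in df.columns:
            errors.append(f"missing column: {col}")
        elif str(df[col].dtype) != dtype:
            errors.append(f"wrong dtype for {col}: {df[col].dtype} (expected {dtype})")
    for col, (low, high) in VALUE_RANGES.items():
        if col in df.columns:
            if df[col].isna().any():
                errors.append(f"unexpected NaNs in {col}")
            if ((df[col] < low) | (df[col] > high)).any():
                errors.append(f"values outside [{low}, {high}] in {col}")
    return errors
```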
The "ML fails silently" part also applies to bugs in your code: Especially when you just started with ML it often happens that your results seem fine (maybe just a little too good), but you have a subtle bug somewhere that doesn’t cause your program to crash, but just calculates something slightly wrong somewhere. These issues can be very hard to notice, so always triple-check your code and if you can, write unit tests for individual functions to make sure they do what you expect them to do. |
You might also want to have a look at Google’s rules of machine learning (the first one being: "Don’t be afraid to launch a product without machine learning.").