It can be useful and prudent to investigate the reasoning behind a machine learning model’s predictions. For example, if you have a classifier that aims to determine whether a photo contains a hot dog, you could take one of its positive predictions and ask, “what is it about this particular photo that most strongly suggests a hot dog?” Or in a more serious context, if a model is meant to identify customers who are good prospects for purchasing a given product, you might look at each of the top-rated customers and ask, “what particular factor makes this customer a good lead?” This could serve as a kind of sanity check on the model’s rationale, or it could allow you to focus on customers who are good leads for only specific reasons.
For one of the large models I’m in charge of at work, which is fairly similar in purpose to the serious example mentioned above, a stakeholder recently asked whether it would be possible to provide justifications for the customers it ranks as the best leads. Because of this, and my own preexisting interest, I’ve been doing a lot of exploration and experimentation with this topic in the past couple of days.
In this post, we’ll see how to use a new Python package called LIME to explain the predictions of a random forest classifier. Then I’ll introduce an alternative method that gives very similar explanations, but works hundreds of times faster.
We will be using the following tools:
```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from lime.lime_tabular import LimeTabularExplainer
from scipy.sparse import hstack
```
All code from this post can be found on my GitHub.
Experimental data set
The code presented later (starting with “Explanations from LIME”) should work for any situation in which we have a pandas dataframe X of features, a pandas series y of targets (1’s and 0’s), and a corresponding trained scikit-learn random forest classifier. But to demonstrate the code in a simple setting, we will use one of scikit-learn’s built-in “toy” datasets, which can be loaded as follows:
```python
from sklearn.datasets import load_wine

dataset = load_wine()
X = pd.DataFrame(dataset.data, columns = dataset.feature_names)
y = dataset.target
```
This data set apparently consists of a list of 178 wines, described by 13 features with mostly self-explanatory names such as “alcohol,” “magnesium,” “flavanoids,” “color_intensity,” and “hue.” Each wine is labeled in the target list y with a 0, 1, or 2, so that the wines fall into 3 classes of some kind. (I must confess that I’ve never touched this data set until today.)
We’ll arbitrarily split the data into a training set of 100 wines and a testing set of the other 78:
```python
Xtrain, Xtest, ytrain, ytest = train_test_split(
    X, y,
    stratify = y,
    train_size = 100,
    random_state = 2018)
```
The random_state parameter should not usually be used, unless we want our code to be exactly reproducible. I use it here so that if you run my code yourself, you should see very close to the same results. Throughout all code in this post, any line that sets random_state = 2018 should be deleted before the code is applied to any other machine learning problem of your own.
I should also note that in my own real-life use case, mentioned before, I do not have a testing set per se; I have a feature set and target list from past months for training a classifier, and a feature set for an upcoming month to feed into the trained classifier and obtain predictions and explanations. In our toy example here, we will not use the target list from the testing set (ytest) for anything at all, except to check the accuracy of our predictions.
Once we train a classifier and start digging into the explanations given for its predictions, we will want to be able to judge whether those explanations make sense. So before we go on, let’s explore the training data. The following command gives an indication of how the distributions of the features vary across the three classes.
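A per-class summary like this can be produced with pandas. The sketch below is an assumption, not the command behind the ranges quoted next, because the statistics used for those ranges are not stated. `class_ranges` is a hypothetical helper that reports the 10th and 90th percentiles of each feature within each class, so its numbers will not necessarily match.

```python
import pandas as pd


def class_ranges(X, y, lower=0.1, upper=0.9):
    """Hypothetical helper: per-class [lower, upper] quantiles of every feature.

    X: feature DataFrame; y: class labels aligned with X's rows (array or Series).
    Returns a DataFrame with one row per feature and (bound, class) columns.
    """
    # Align the labels with X's index so grouping works after a shuffled split
    labels = pd.Series(y, index=X.index, name='class')
    grouped = X.groupby(labels)
    lo = grouped.quantile(lower)   # rows: classes, columns: features
    hi = grouped.quantile(upper)
    table = pd.concat({'low': lo, 'high': hi}, axis=0)  # rows: (bound, class)
    return table.T                 # rows: features, columns: (bound, class)


# Tiny demo with two well-separated classes
X_demo = pd.DataFrame({'a': [1.0, 2.0, 3.0, 4.0, 10.0, 11.0, 12.0, 13.0]})
y_demo = [0, 0, 0, 0, 1, 1, 1, 1]
print(class_ranges(X_demo, y_demo, 0.0, 1.0))
```

With the variables defined earlier, `class_ranges(Xtrain, ytrain).loc[['alcohol', 'proline', 'color_intensity']]` would show the low and high ends per class for a few of the features discussed below.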
The resulting output is shown below. For example, the measurement of alcohol for most wines in class 0 varies from 13.208 to 14.356. In class 1, it varies from 11.787 to 13.133. In class 2, it varies from 12.486 to 13.98. What can we do with this information? Well, one thing we can say is that a wine with alcohol measurement below 13 is probably not in class 0.
Examining all the summaries below, which feature is most valuable for determining whether a wine belongs to class 0?
In my opinion, the proline measurement seems to be the single most useful piece of information to decide whether a wine is in class 0. If proline is above 840 or so, the wine is almost certainly class 0.
The color intensity looks useful for finding class 1 wines. If a wine has color intensity below 4.1 or so, we should be pretty confident that it’s in class 1.
There are a few good features for identifying class 2 wines – we could require that flavanoids are below 1.2, or hue is below 0.86, or the od280/od315 measurement is below 2. (If all of these conditions are met, we would be very sure indeed.)
The topic at hand and the code below (starting with “Explanations from LIME”) are largely oriented toward binary classification problems, so we’ll treat the wine example as three separate binary problems. For each class, we’ll train a separate random forest classifier to determine whether each data point belongs to that class or not:
```python
# One binary (one-vs-rest) random forest per class
clf = []
for clas in range(3):
    clf.append(
        RandomForestClassifier(
            n_estimators = 100,
            n_jobs = -1,
            random_state = 2018
        ).fit(Xtrain, ytrain == clas)
    )
```
Explanations from LIME
There has been a good deal of buzz in the world of machine learning in the past year or so about a new method for explaining the predictions of advanced models, which is what first kindled my interest in the topic. The new method is called “Local Interpretable Model-Agnostic Explanations,” or LIME for short.
- “Local” refers to the main mathematical idea behind the method: each explanation is derived from how the model behaves in a small neighborhood around the single data point being explained. You can read more about it here.
- “Interpretable” conveys that the method explains predictions in a format that humans can read and understand.
- “Model-agnostic” refers to the method’s applicability to any machine learning classifier, regardless of the type of algorithm underlying it (e.g. SVM, random forests, deep learning, etc.).
To apply LIME to our example data set and classifiers, we must first set up an “explainer” using the training data:
```python
explainer = LimeTabularExplainer(
    Xtrain.values,
    feature_names = Xtrain.columns,
    random_state = 2018)
```
Now, as of this writing, LIME is only designed to explain one prediction at a time. For example, if my classifier for class 1 predicts that wine #88 belongs in class 1, here’s how I could ask why:
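```python
# Explain the class 1 model's prediction for wine #88 (label 1 is explained by default)
explainer.explain_instance(
    Xtest.loc[88, :].values,
    clf[1].predict_proba
).as_list()
```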
The output of this command is a list of reasons (10 reasons by default) with a measure of the strength of each one:
```
[('color_intensity <= 3.14', 0.3897),
 ('alcohol <= 12.41', 0.1860),
 ('660.00 < proline <= 973.75', -0.08321),
 ('1.88 < malic_acid <= 3.15', -0.04421),
 ('magnesium <= 88.00', 0.04324),
 ('1.01 < flavanoids <= 1.89', 0.02799),
 ('2.35 < ash <= 2.57', -0.02716),
 ('0.97 < hue <= 1.12', 0.02564),
 ('1.85 < od280/od315_of_diluted_wines <= 2.76', 0.01776),
 ('alcalinity_of_ash > 21.50', 0.01516)]
```
In this case, the two strongest reasons for predicting that wine #88 is in class 1 are:
- color intensity ≤ 3.14
- alcohol ≤ 12.41
Note that some of the reasons in the list above have negative strength. This means that they are actually reasons that wine #88 might not be in class 1.
The following code will allow us to use the LIME explainer on an entire data set, to find the top explanations (if any) in favor of a positive prediction in every case.
```python
def explain_row(clf, row, num_reasons = 2):
    '''
    Produce LIME explanations for a single row of data.
      * `clf` is a binary classifier (with a predict_proba method),
      * `row` is a row of features data (a NumPy array),
      * `num_reasons` (default 2) is the number of reasons/explanations
        to be produced.
    Uses the global `explainer` created above.
    '''
    exp = [
        exp_pair[0] for exp_pair in         # Get each explanation
        explainer.explain_instance(         # from the LIME explainer
            row, clf.predict_proba,         # for given row and classifier
            labels = [1],                   # and label 1 ("positives")
            num_features = num_reasons      # for `num_reasons` explanations
        ).as_list()
        if exp_pair[1] > 0                  # only for pos. explanations
    ][:num_reasons]

    # Fill in any missing explanations with blanks
    exp += [''] * (num_reasons - len(exp))
    return exp


def predict_explain(rf, X, num_reasons = 2):
    '''
    Produce scores and LIME explanations for every row in a data frame.
      * `rf` is a binary classifier with a predict_proba method,
      * `X` is the features data frame,
      * `num_reasons` (default 2) is the number of reasons/explanations
        to be produced for each row.
    '''
    # Prepare the structure to be returned (same index as X, no columns yet)
    pred_ex = pd.DataFrame(index = X.index)

    # Get the scores from the classifier
    pred_ex['SCORE'] = rf.predict_proba(X)[:, 1]

    # Get the reasons/explanations for each row, then regroup them by position
    cols = zip(*[explain_row(rf, row, num_reasons) for row in X.values])

    # Return the results
    for n in range(num_reasons):
        pred_ex['REASON%d' % (n + 1)] = list(next(cols))
    return pred_ex
```
Now, the following command produces the explanations for predictions from the class 0 model, showing the results only for the top 20 scores for class 0.
```python
predict_explain(clf[0], Xtest).assign(
    TRUE_CLASS = ytest
).sort_values('SCORE', ascending = False).head(20)
```
| Wine # | SCORE | REASON1 | REASON2 | TRUE_CLASS |
|---|---|---|---|---|
| 58 | 1 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 57 | 1 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 0 | 1 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 26 | 0.98 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 49 | 0.98 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 31 | 0.98 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 51 | 0.96 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 47 | 0.96 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 9 | 0.95 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 13 | 0.95 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 7 | 0.95 | proline > 973.75 | alcohol > 13.74 | 0 |
| 28 | 0.94 | flavanoids > 2.81 | total_phenols > 2.80 | 0 |
| 20 | 0.92 | flavanoids > 2.81 | total_phenols > 2.80 | 0 |
| 17 | 0.91 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 22 | 0.91 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 6 | 0.9 | proline > 973.75 | 2.21 < total_phenols <= 2.80 | 0 |
| 50 | 0.86 | proline > 973.75 | flavanoids > 2.81 | 0 |
| 32 | 0.81 | proline > 973.75 | 13.13 < alcohol <= 13.74 | 0 |
| 41 | 0.8 | proline > 973.75 | 2.21 < total_phenols <= 2.80 | 0 |
| 38 | 0.79 | proline > 973.75 | 2.21 < total_phenols <= 2.80 | 0 |
As my earlier intuition suggested, when a wine is predicted to be in class 0, it is usually due to having a high proline measurement. Having high flavanoids is also a common reason; looking back at the distributions table, it makes sense that most wines with flavanoids > 2.81 would be in class 0. Wine #20 is one of the cases where flavanoids, rather than proline, was the main explanation for class 0. If we examine that data point, it turns out that its proline measurement is only 780, which isn’t high enough to rule out class 2 with much certainty.
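To look at that row directly, select it by its index label. Because `train_test_split` keeps the original DataFrame index, `Xtest.loc[20]` and `X.loc[20]` refer to the same wine. The sketch below reloads the data so that it runs on its own; the choice of columns is just illustrative.

```python
import pandas as pd
from sklearn.datasets import load_wine

dataset = load_wine()
X = pd.DataFrame(dataset.data, columns=dataset.feature_names)

# Wine #20: proline plus the two features named in its explanations
row20 = X.loc[20, ['proline', 'flavanoids', 'total_phenols']]
print(row20)  # proline should be 780, as noted above
```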
Repeating the command above, but for clf[1], produces the explanations for predictions from the class 1 model, again showing the results only for the top 20 scores.
| Wine # | SCORE | REASON1 | REASON2 | TRUE_CLASS |
|---|---|---|---|---|
| 128 | 1 | color_intensity <= 3.14 | proline <= 495.00 | 1 |
| 97 | 1 | color_intensity <= 3.14 | proline <= 495.00 | 1 |
| 86 | 1 | color_intensity <= 3.14 | alcohol <= 12.41 | 1 |
| 117 | 0.99 | color_intensity <= 3.14 | proline <= 495.00 | 1 |
| 113 | 0.99 | color_intensity <= 3.14 | alcohol <= 12.41 | 1 |
| 106 | 0.99 | alcohol <= 12.41 | ash <= 2.20 | 1 |
| 114 | 0.99 | color_intensity <= 3.14 | proline <= 495.00 | 1 |
| 78 | 0.98 | alcohol <= 12.41 | malic_acid <= 1.64 | 1 |
| 99 | 0.98 | color_intensity <= 3.14 | proline <= 495.00 | 1 |
| 94 | 0.98 | alcohol <= 12.41 | proline <= 495.00 | 1 |
| 108 | 0.98 | color_intensity <= 3.14 | proline <= 495.00 | 1 |
| 105 | 0.96 | color_intensity <= 3.14 | proline <= 495.00 | 1 |
| 88 | 0.94 | color_intensity <= 3.14 | alcohol <= 12.41 | 1 |
| 109 | 0.93 | color_intensity <= 3.14 | alcohol <= 12.41 | 1 |
| 110 | 0.93 | color_intensity <= 3.14 | alcohol <= 12.41 | 1 |
| 102 | 0.93 | color_intensity <= 3.14 | proline <= 495.00 | 1 |
| 74 | 0.91 | alcohol <= 12.41 | malic_acid <= 1.64 | 1 |
| 124 | 0.91 | color_intensity <= 3.14 | proline <= 495.00 | 1 |
| 104 | 0.9 | color_intensity <= 3.14 | ash <= 2.20 | 1 |
| 129 | 0.89 | color_intensity <= 3.14 | alcohol <= 12.41 | 1 |
The reasons for predicting class 1 usually involve low color intensity, as intuition suggested before.
Repeating the command again with clf[2], here are the top 20 results:
| Wine # | SCORE | REASON1 | REASON2 | TRUE_CLASS |
|---|---|---|---|---|
| 176 | 1 | hue <= 0.82 | flavanoids <= 1.01 | 2 |
| 155 | 1 | hue <= 0.82 | flavanoids <= 1.01 | 2 |
| 174 | 1 | hue <= 0.82 | flavanoids <= 1.01 | 2 |
| 148 | 0.98 | hue <= 0.82 | flavanoids <= 1.01 | 2 |
| 169 | 0.94 | hue <= 0.82 | flavanoids <= 1.01 | 2 |
| 163 | 0.94 | hue <= 0.82 | flavanoids <= 1.01 | 2 |
| 171 | 0.91 | hue <= 0.82 | flavanoids <= 1.01 | 2 |
| 131 | 0.9 | hue <= 0.82 | od280/od315 <= 1.85 | 2 |
| 170 | 0.85 | hue <= 0.82 | flavanoids <= 1.01 | 2 |
| 136 | 0.81 | hue <= 0.82 | flavanoids <= 1.01 | 2 |
| 149 | 0.81 | hue <= 0.82 | od280/od315 <= 1.85 | 2 |
| 159 | 0.8 | hue <= 0.82 | od280/od315 <= 1.85 | 2 |
| 150 | 0.79 | hue <= 0.82 | od280/od315 <= 1.85 | 2 |
| 151 | 0.77 | hue <= 0.82 | od280/od315 <= 1.85 | 2 |
| 154 | 0.75 | hue <= 0.82 | flavanoids <= 1.01 | 2 |
| 140 | 0.72 | hue <= 0.82 | flavanoids <= 1.01 | 2 |
| 162 | 0.67 | flavanoids <= 1.01 | malic_acid > 3.15 | 2 |
| 142 | 0.67 | flavanoids <= 1.01 | malic_acid > 3.15 | 2 |
| 141 | 0.66 | hue <= 0.82 | flavanoids <= 1.01 | 2 |
| 130 | 0.62 | hue <= 0.82 | od280/od315 <= 1.85 | 2 |
The three most common reasons given in this case (hue, flavanoids, and od280/od315) agree strongly with earlier intuition.
Difficulties with LIME
As noted before, LIME is only designed to explain one prediction at a time, and the process is time-consuming. Each of the sets of explanations shown above required around 45 seconds to generate on my laptop.
Part of the process for producing each explanation involves taking a large number of random samples in feature space – 5000 samples by default – and fitting a linear model. (Again, a deeper description of the process can be read here.) The time cost of this sampling can be reduced by adjusting the num_samples parameter of the explain_instance method. For example, reducing the number of samples to 50 seems to cut the time required to around 10 seconds. However, this has tangible effects on the explanations given; you might say the resulting explanations are “low fidelity.”
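To make the sampling-and-fitting step concrete, here is a simplified NumPy sketch of the idea. It is not lime's implementation: lime's tabular explainer discretizes continuous features by default (which is why its reasons look like `color_intensity <= 3.14`), samples differently by default, and fits a ridge regression. This sketch draws Gaussian noise around the row being explained (closer to lime's `sample_around_instance=True` option). It weights each sample with an exponential kernel of the same form as lime's default and then fits plain weighted least squares. `local_linear_explanation` is a hypothetical name.

```python
import numpy as np


def local_linear_explanation(predict_fn, x, scale, num_samples=5000,
                             kernel_width=None, seed=None):
    """Hypothetical, simplified LIME-style local surrogate (not lime's API).

    predict_fn: maps an (n, d) array to n positive-class scores
    x:          (d,) row to explain
    scale:      (d,) per-feature spread used for sampling (e.g. training std)
    Returns (intercept, coefficients) of a proximity-weighted linear fit.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=float)
    scale = np.asarray(scale, dtype=float)
    d = x.shape[0]
    if kernel_width is None:
        kernel_width = 0.75 * np.sqrt(d)       # same default width as lime's tabular explainer
    # 1. Perturb: Gaussian noise in standardized units around x
    z = rng.normal(size=(num_samples, d))
    z[0] = 0.0                                 # keep the original row as sample 0
    samples = x + z * scale
    # 2. Ask the black-box model for its scores
    scores = np.asarray(predict_fn(samples), dtype=float)
    # 3. Proximity weights from the standardized distance to x
    dist = np.linalg.norm(z, axis=1)
    weights = np.sqrt(np.exp(-dist ** 2 / kernel_width ** 2))
    # 4. Weighted least squares: scale each row by sqrt(weight), then solve
    A = np.column_stack([np.ones(num_samples), samples])
    w = np.sqrt(weights)
    beta, *_ = np.linalg.lstsq(A * w[:, None], scores * w, rcond=None)
    return beta[0], beta[1:]


# Sanity check: for a model that is exactly linear, the surrogate recovers it
f = lambda S: 1.0 + 3.0 * S[:, 0] - 2.0 * S[:, 1]
b0, coef = local_linear_explanation(f, [0.5, -1.0], [1.0, 2.0],
                                    num_samples=500, seed=0)
print(b0, coef)
```

For a real classifier, the coefficients describe the model’s local slopes around the row. More samples give steadier coefficients at a proportionally higher cost, which is the trade-off controlled by `num_samples`.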
In any case, even a fourfold decrease in the time needed for LIME to generate explanations is not good enough. With each explanation requiring about ⅛ of a second, even a modest-sized machine learning project would require a day or more to explain all its predictions. In my own main use case, the model produces around 12 million predictions per month, and due to the size of the feature set, the low-fidelity LIME approach requires around 1 full second per explanation. Even if I only generate reasons for, say, the top 10% of scores, the process would require two weeks!
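The estimates above are easy to check. The snippet below just redoes the arithmetic with the round numbers quoted in this section: about 10 seconds for the 78 test wines at 50 samples, and about 1 second per explanation on the large model. The variable names are illustrative.

```python
# Wine example: ~10 s to explain all 78 test wines with num_samples = 50
secs_per_set_low_fi = 10.0
n_test = 78
per_explanation = secs_per_set_low_fi / n_test        # about 1/8 of a second

# Large model: ~12 million predictions a month, explaining the top 10%
monthly_predictions = 12_000_000
fraction_explained = 0.10
secs_each_large_model = 1.0                           # low-fidelity LIME
total_secs = monthly_predictions * fraction_explained * secs_each_large_model
days = total_secs / 86_400                            # seconds per day

print(f"{per_explanation:.3f} s per wine; {days:.1f} days per month of scores")
# 0.128 s per wine; 13.9 days per month of scores
```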
Explanations by tree interpretation
A couple of years before LIME’s debut, Ando Saabas released a beautiful little package called treeinterpreter, apparently with a purpose quite similar to LIME’s, but restricted in its scope to tree-based models, such as random forests. There is a great explanation of the treeinterpreter package here.
In short, treeinterpreter traces the path of each data point through the decision tree(s) of a given machine learning model, recording how each feature contributes to the data point’s final score at each tree split along the way. At the end, the feature(s) with the highest total contributions constitute the “explanation” of the score. For example, the explanation given for the wine that is ranked highest by the class 2 model might simply be “hue and flavanoids.”
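The path-tracing idea can be sketched in a few lines against scikit-learn’s low-level tree arrays (`tree_.children_left`, `tree_.children_right`, `tree_.feature`, `tree_.threshold`, `tree_.value`). This is a minimal re-implementation for illustration, assuming a fitted binary `RandomForestClassifier`. `forest_contributions` is a hypothetical helper, not part of treeinterpreter or tree_explainer. Each split’s change in positive-class probability is credited to the feature it splits on, so the root value (the “bias”) plus all contributions adds up to the forest’s score.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def forest_contributions(forest, x, positive_class=1):
    """Hypothetical helper: split a fitted random forest's positive-class
    probability for one row `x` into bias + per-feature contributions,
    following the treeinterpreter idea (not its actual API)."""
    # Trees compare features in float32, so do the same here
    x = np.asarray(x, dtype=np.float32)
    bias, contrib = 0.0, np.zeros(x.shape[0])
    for est in forest.estimators_:
        t = est.tree_
        # Per-node class totals; normalizing gives P(positive) at every node
        # (works whether tree_.value holds counts or fractions)
        vals = t.value[:, 0, :]
        p = vals[:, positive_class] / vals.sum(axis=1)
        node = 0
        bias += p[0]                          # root value: the tree's "bias"
        while t.children_left[node] != -1:    # -1 marks a leaf
            f = t.feature[node]
            if x[f] <= t.threshold[node]:
                child = t.children_left[node]
            else:
                child = t.children_right[node]
            # Credit the change in P(positive) to the feature used at this split
            contrib[f] += p[child] - p[node]
            node = child
    n = len(forest.estimators_)
    return bias / n, contrib / n


# Small demo on synthetic data
X_demo, y_demo = make_classification(n_samples=300, n_features=5, random_state=0)
rf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_demo, y_demo)
bias, contrib = forest_contributions(rf, X_demo[0])
print(bias + contrib.sum(), rf.predict_proba(X_demo[:1])[0, 1])  # should agree
top_two = np.argsort(contrib)[::-1][:2]  # indices of the two most positive features
print(top_two)
```

Ranking features by their summed contributions, as in `top_two`, is what turns these numbers into a short “hue and flavanoids” style explanation.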
The treeinterpreter package goes a long way toward producing LIME-like explanations for random forests (and other tree-based classifiers), and I’ve made some simple tweaks that, I believe, take it the rest of the way to full functionality.
I have adapted the treeinterpreter package, expanding its functionality so that when it traces each data point through a tree, it keeps track of any potentially harmful thresholds encountered. As a concrete example, consider the following decision tree diagram, borrowed from the explanatory page I linked before:
Represented in red is the path of a data point with RM > 6.94 and RM ≤ 7.44 and NOX > 0.66. The fact that RM > 6.94 adds value to the final prediction (it goes from 22.60 to 37.42 at that split). The fact that RM ≤ 7.44 takes away value. RM is the most important feature for this data point, since it contributes the most positive value to the final prediction. However, if RM had been less than or equal to 6.94, the feature would have been less valuable. This is what I mean by a “potentially harmful” threshold. For this data point’s score, we report the following explanation: RM > 6.94.
Here is another example:
For this data point, we have RM > 6.94 and RM > 7.44 and RM ≤ 8.75. Each of these facts adds some value to the final prediction (it increases at each split), so those thresholds are all “potentially harmful.” For this data point’s score, we report this explanation: 7.44 < RM ≤ 8.75.
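The interval rule in these two examples can be written as a small function. This is a hypothetical reconstruction from the examples above, not tree_explainer’s actual code. For one feature, it keeps only the splits that added value, then combines the tightest lower bound (from `>` branches) with the tightest upper bound (from `<=` branches). How the real package combines paths across many trees is not covered here.

```python
def interval_explanation(steps, feature):
    """Hypothetical helper: build a reason such as '7.44 < RM <= 8.75' from the
    splits on `feature` along one decision path that added value to the score.

    steps: list of (feature_name, threshold, went_left, contribution), where
    went_left is True when the row satisfied `feature_name <= threshold`.
    """
    lower, upper = None, None
    for name, threshold, went_left, contribution in steps:
        # Only value-adding ("potentially harmful") thresholds are kept
        if name != feature or contribution <= 0:
            continue
        if went_left:   # row satisfied name <= threshold: an upper bound
            upper = threshold if upper is None else min(upper, threshold)
        else:           # row satisfied name > threshold: a lower bound
            lower = threshold if lower is None else max(lower, threshold)
    if lower is None and upper is None:
        return ''
    if lower is None:
        return f'{feature} <= {upper:.2f}'
    if upper is None:
        return f'{feature} > {lower:.2f}'
    return f'{lower:.2f} < {feature} <= {upper:.2f}'


# Second example: RM > 6.94, RM > 7.44, RM <= 8.75, all adding value.
# 14.82 is the 22.60 -> 37.42 jump from the text; the other two
# contributions are hypothetical placeholder values.
path = [('RM', 6.94, False, 14.82), ('RM', 7.44, False, 1.0), ('RM', 8.75, True, 1.0)]
print(interval_explanation(path, 'RM'))   # 7.44 < RM <= 8.75
```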
My adaptation of the treeinterpreter package applies the logic demonstrated above to an entire data set to produce explanations for its positive predictions. I call the package “tree_explainer,” and it can be found here on my GitHub. Let’s apply it to the example at hand.
Using tree_explainer to generate explanations for the best 20 scores from the class 0 model:
```python
import tree_explainer

tree_explainer.predict_explain(clf[0], Xtest).assign(
    TRUE_CLASS = ytest
).sort_values('SCORE', ascending = False).head(20)
```
| Wine # | SCORE | REASON1 | REASON2 | TRUE_CLASS |
|---|---|---|---|---|
| 58 | 1 | proline > 1010.00 | 2.57 < total_phenols <= 3.41 | 0 |
| 57 | 1 | proline > 1010.00 | 2.57 < total_phenols <= 3.28 | 0 |
| 0 | 1 | proline > 1010.00 | 2.57 < total_phenols <= 3.28 | 0 |
| 26 | 0.98 | proline > 1010.00 | 2.57 < total_phenols <= 3.28 | 0 |
| 49 | 0.98 | proline > 1045.00 | 2.66 < total_phenols <= 3.28 | 0 |
| 31 | 0.98 | proline > 1045.00 | 2.66 < total_phenols <= 3.28 | 0 |
| 51 | 0.96 | proline > 1082.50 | flavanoids > 2.70 | 0 |
| 47 | 0.96 | proline > 895.00 | 2.57 < total_phenols <= 3.28 | 0 |
| 9 | 0.95 | proline > 1010.00 | 2.57 < total_phenols <= 3.28 | 0 |
| 13 | 0.95 | proline > 1010.00 | flavanoids > 2.70 | 0 |
| 7 | 0.95 | proline > 1010.00 | 2.57 < total_phenols <= 3.28 | 0 |
| 28 | 0.94 | proline > 895.00 | flavanoids > 2.93 | 0 |
| 20 | 0.92 | 2.57 < total_phenols <= 3.28 | flavanoids > 2.97 | 0 |
| 17 | 0.91 | proline > 1122.50 | 2.66 < total_phenols <= 3.28 | 0 |
| 22 | 0.91 | proline > 1010.00 | 2.57 < total_phenols <= 3.28 | 0 |
| 6 | 0.9 | proline > 1082.50 | alcohol > 13.38 | 0 |
| 50 | 0.86 | proline > 1045.00 | 2.57 < total_phenols <= 3.28 | 0 |
| 32 | 0.81 | proline > 970.00 | 2.67 < flavanoids <= 2.79 | 0 |
| 41 | 0.8 | proline > 1010.00 | flavanoids > 2.66 | 0 |
| 38 | 0.79 | proline > 1017.50 | flavanoids > 2.62 | 0 |
From the class 1 model:
```python
tree_explainer.predict_explain(clf[1], Xtest).assign(
    TRUE_CLASS = ytest
).sort_values('SCORE', ascending = False).head(20)
```
| Wine # | SCORE | REASON1 | REASON2 | TRUE_CLASS |
|---|---|---|---|---|
| 128 | 1 | color_intensity <= 3.26 | proline <= 375.00 | 1 |
| 97 | 1 | color_intensity <= 3.43 | proline <= 505.00 | 1 |
| 86 | 1 | color_intensity <= 3.22 | 12.16 < alcohol <= 12.21 | 1 |
| 117 | 0.99 | color_intensity <= 3.43 | proline <= 375.00 | 1 |
| 113 | 0.99 | color_intensity <= 3.43 | proline <= 476.00 | 1 |
| 106 | 0.99 | color_intensity <= 3.43 | alcohol <= 12.32 | 1 |
| 114 | 0.99 | color_intensity <= 3.43 | proline <= 476.00 | 1 |
| 78 | 0.98 | color_intensity <= 3.43 | alcohol <= 12.37 | 1 |
| 99 | 0.98 | color_intensity <= 3.28 | proline <= 476.00 | 1 |
| 94 | 0.98 | color_intensity <= 3.28 | alcohol <= 12.16 | 1 |
| 108 | 0.98 | color_intensity <= 3.43 | proline <= 430.00 | 1 |
| 105 | 0.96 | color_intensity <= 3.26 | proline <= 375.00 | 1 |
| 88 | 0.94 | color_intensity <= 3.35 | alcohol <= 12.16 | 1 |
| 109 | 0.93 | color_intensity <= 3.22 | alcohol <= 12.06 | 1 |
| 110 | 0.93 | color_intensity <= 3.22 | alcohol <= 12.16 | 1 |
| 102 | 0.93 | color_intensity <= 3.26 | proline <= 446.00 | 1 |
| 74 | 0.91 | color_intensity <= 3.22 | alcohol <= 12.06 | 1 |
| 124 | 0.91 | color_intensity <= 3.41 | proline <= 430.00 | 1 |
| 104 | 0.9 | color_intensity <= 3.43 | alcohol <= 12.65 | 1 |
| 129 | 0.89 | color_intensity <= 3.43 | alcohol <= 12.16 | 1 |
From the class 2 model:
```python
tree_explainer.predict_explain(clf[2], Xtest).assign(
    TRUE_CLASS = ytest
).sort_values('SCORE', ascending = False).head(20)
```
| Wine # | SCORE | REASON1 | REASON2 | TRUE_CLASS |
|---|---|---|---|---|
| 176 | 1 | hue <= 0.80 | flavanoids <= 0.85 | 2 |
| 155 | 1 | hue <= 0.80 | od280/od315 <= 1.78 | 2 |
| 174 | 1 | hue <= 0.80 | od280/od315 <= 1.78 | 2 |
| 148 | 0.98 | hue <= 0.80 | od280/od315 <= 1.78 | 2 |
| 169 | 0.94 | hue <= 0.76 | od280/od315 <= 2.01 | 2 |
| 163 | 0.94 | hue <= 0.70 | flavanoids <= 0.89 | 2 |
| 171 | 0.91 | hue <= 0.80 | 1.63 < od280/od315 <= 1.78 | 2 |
| 131 | 0.9 | hue <= 0.76 | od280/od315 <= 1.48 | 2 |
| 170 | 0.85 | hue <= 0.70 | flavanoids <= 0.89 | 2 |
| 136 | 0.81 | hue <= 0.84 | flavanoids <= 0.89 | 2 |
| 149 | 0.81 | hue <= 0.76 | od280/od315 <= 1.48 | 2 |
| 159 | 0.8 | hue <= 0.76 | 1.78 < od280/od315 <= 1.79 | 2 |
| 150 | 0.79 | hue <= 0.76 | od280/od315 <= 1.48 | 2 |
| 151 | 0.77 | od280/od315 <= 1.48 | hue <= 0.76 | 2 |
| 154 | 0.75 | hue <= 0.80 | od280/od315 <= 1.63 | 2 |
| 140 | 0.72 | hue <= 0.81 | flavanoids <= 0.85 | 2 |
| 162 | 0.67 | 0.59 < flavanoids <= 0.89 | malic_acid > 2.93 | 2 |
| 142 | 0.67 | flavanoids <= 0.81 | malic_acid > 2.93 | 2 |
| 141 | 0.66 | hue <= 0.81 | flavanoids <= 0.85 | 2 |
| 130 | 0.62 | hue <= 0.81 | od280/od315 <= 1.48 | 2 |
For each model, the reasons produced tend to be quite similar to those given by a LIME explainer.
Final notes about tree interpretation
Given that it seems to generate explanations comparable to the LIME system, the most remarkable thing about the tree interpretation approach is its extreme speed – each of the sets of explanations just shown required around 250 milliseconds! This impressive speed is mostly attained by abandoning LIME’s model-agnosticism; tree interpretation only works for tree-based models.
Currently, my approach takes into account every split in every decision tree where any positive value was added to the data point’s final score. In large random forests made up of many hundreds of trees, with splits that may contribute very small values to a score, some more sophisticated condition should probably be used to decide which splits to consider. In fact, I have already begun experimenting in this area, but the details are a bit too deep for this blog post.
Again, the most amazing thing about this tree-centric approach to prediction explanations is its speed compared to LIME. In tentative applications to the model I mentioned before, this approach currently requires around 20 milliseconds per explanation, making it roughly 50 times as fast as low-fidelity LIME.
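The speed comparisons follow directly from the timings quoted in this post. The variable names below are illustrative.

```python
# Large model, seconds per explanation
lime_low_fi = 1.0        # low-fidelity LIME
tree_interp = 0.020      # tree interpretation
speedup = lime_low_fi / tree_interp
print(round(speedup))                            # 50

# Wine example, seconds per set of 78 explanations:
# default LIME (~45 s) and 50-sample LIME (~10 s) vs tree interpretation (~0.25 s)
print(round(45.0 / 0.25), round(10.0 / 0.25))    # 180 40

# Top 10% of ~12 million monthly predictions with tree interpretation
hours = 12_000_000 * 0.10 * tree_interp / 3600
print(f"{hours:.1f} hours")                      # 6.7 hours
```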