Visualizing variable importance using SHAP and cross-validation

Lucas Ramos
Jun 26, 2020 · 4 min read

We have all seen how amazing and insightful SHAP explanations can be. However, when writing a scientific paper, or in any other situation that requires running multiple cross-validation iterations, the images start to stack up and it becomes difficult to report results.

If you’re just looking for an answer you can skip this intro and go straight to the section below (named The Real Deal) where I explain the code and approach.

Colab code can be found here:

Machine learning models have long been described as “black boxes” due to their complex nature. The more complex and robust models become, the harder it is to obtain meaningful variable importance.

Of course, the field of model visualization and interpretation is as active as the AI field itself. We have all seen how powerful and helpful tools such as SHAP and LIME are at visualizing and explaining machine learning models for both image and tabular data.

LIME (Local Interpretable Model-Agnostic Explanations) became popular around 2016, when the paper “Why Should I Trust You?”: Explaining the Predictions of Any Classifier was published. It showed particularly interesting cases, like the one below, where a husky was classified as a wolf due to the snow in the background (that’s how the authors found out how the network was actually telling dogs and wolves apart).

One issue with explainers such as LIME is that they often train a new model to explain the one you created, and the explanations are “sample-wise”, which makes them hard to condense into overall conclusions.

Example of sample-wise prediction from LIME

More recently, SHAP condensed knowledge from multiple explainers, including LIME, into its own visualization toolkit. While LIME is “model-agnostic”, meaning it can be used with any classifier, SHAP has a dedicated module for each “family of classifiers”, such as tree-based models, linear models, etc.

Some improvements of SHAP over many other explainers are the following:

1- Global explanations for the model/dataset instead of just sample-wise ones;

2- Ways to evaluate local feature interactions;

3- Explanations/visualizations generated from your trained model instead of from a newly trained one.

An example of one of my favorite SHAP visualizations is shown below. The summary plot shows not only variable importance (ordered from top to bottom), but also how the feature values (on a color spectrum from blue for low to red for high) affect the label prediction. Here, high values of alcohol push the prediction toward the positive label (SHAP values above zero). More about this can be found HERE.

SHAP summary plot

It is relatively easy to create such a plot. Below we build one for the breast_cancer dataset.

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
import shap
import pandas as pd

# loading and preparing the data
data = load_breast_cancer()
X = data.data
y = data.target
columns = data.feature_names
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
X_train = pd.DataFrame(X_train, columns=columns)
X_test = pd.DataFrame(X_test, columns=columns)

# training the model
clf = RandomForestClassifier(random_state=0)
clf.fit(X_train, y_train)

# explaining the model
explainer = shap.TreeExplainer(clf)
shap_values = explainer.shap_values(X_test)

# shap_values[1] holds the SHAP values for the positive label
shap.summary_plot(shap_values[1], X_test)

Which results in a summary plot in the same style as the one above, now computed on the breast-cancer test set.
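As a quick illustration of point 2 from the list above (ways to evaluate local feature interactions), the same output can feed SHAP’s dependence plot. This is just a minimal sketch reusing shap_values and X_test from the snippet above; “worst radius” is simply one feature of this dataset, picked for illustration:

# Reuses shap_values and X_test from the code above.
# Plots the SHAP value of "worst radius" against its feature value,
# colored by an automatically chosen interacting feature.
shap.dependence_plot("worst radius", shap_values[1], X_test)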

The Real Deal

Now, that works fine for one iteration, but everyone knows it is good practice to run multiple cross-validation iterations when developing your models, to reduce the risk of a lucky train/test split.

To generate a single plot summarizing all cross-validation iterations, I suggest using k-fold cross-validation. This way we can compute SHAP values for the test set of each fold (see the code below).
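The Colab notebook has the full code; below is a minimal sketch of the idea, assuming the same breast-cancer data and random forest as in the example above (the choice of 5 folds and the np.vstack/pd.concat bookkeeping are my own, not anything prescribed by SHAP):

import numpy as np
import pandas as pd
import shap
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold

# same data as before
data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target

shap_values_per_fold = []  # SHAP values of each held-out fold
test_sets_per_fold = []    # the corresponding test-fold rows

kf = KFold(n_splits=5, shuffle=True, random_state=0)
for train_idx, test_idx in kf.split(X):
    X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
    y_train = y[train_idx]

    # train a fresh model on the training folds
    clf = RandomForestClassifier(random_state=0)
    clf.fit(X_train, y_train)

    # explain only the held-out fold
    explainer = shap.TreeExplainer(clf)
    shap_values = explainer.shap_values(X_test)
    shap_values_per_fold.append(shap_values[1])  # index 1 = positive label, as above
    test_sets_per_fold.append(X_test)

# stack all folds together and draw a single summary plot
all_shap_values = np.vstack(shap_values_per_fold)
all_test_sets = pd.concat(test_sets_per_fold)
shap.summary_plot(all_shap_values, all_test_sets)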

After running the code for all folds, we get a single plot summarizing feature importance for the whole experiment.

Summary plot for a single fold on the left and for all folds together on the right.

Although the results do not change dramatically (the conclusions about the top features would probably stay the same), there is a clear difference between the plots, especially in the density and spread of the values. We also have to keep in mind that this is a relatively easy dataset; for more complex problems, reporting results over multiple iterations might actually change your final conclusions.

Finally, this might not be the best way to do it, so I’m curious and open to discussing your ideas for handling it.
