from sklearn import datasets
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from sklearn import metrics
from sklearn.inspection import permutation_importance
from sklearn.inspection import PartialDependenceDisplay
import shap
# Load the JavaScript needed for interactive SHAP plots
shap.initjs()
# Helper function to see methods in object
# Might be useful when working through this exercise
def object_methods(obj):
    '''
    Helper function to list methods associated with an object
    '''
    try:
        methods = [method_name for method_name in dir(obj)
                   if callable(getattr(obj, method_name))]
        print('Below are the methods for object: ', obj)
        for method in methods:
            print(method)
    except Exception:
        print("Error")

18 Exercise Solution: Explainable AI (Penguins Classification Dataset)
In this notebook, we’ll be exploring how to use a couple of different explainable AI techniques.
We’ll be using a different dataset to take a look into this today. The penguins dataset is a great dataset for practising classification problems. This data has been pulled using the excellent palmerpenguins package.
What we need to know is that
- The dataset is made up of 344 rows (instances).
- Each row has columns pertaining to sex, species, the island on which the penguin was found, bill length, bill depth, flipper length, and body mass.
- There are three species (classes) to consider: Adelie, Chinstrap and Gentoo. These are the targets.
In this exercise you will need to go through the code and fill in any missing spaces.
By the end of this exercise you should know
- how to calculate feature importance using the MDI method for tree-based models
- how to calculate feature importance for any model using the permutation feature importance method
- how to create partial dependence plots (PDPs) and individual conditional expectation (ICE) plots for any model
- how to use the SHAP library to understand a model
The SHAP code does vary subtly for different kinds of model; we will just be working with an XGBoost model in this case to match what we’ve done in the lecture.
18.0.0.1 Library Imports
18.0.0.2 Load & Clean Data
Run this cell to load the dataframe.
penguins = pd.read_csv("../datasets/penguins.csv")
Examine the dataset with your choice(s) of function(s).
penguins.head()
| | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | year | island_Biscoe | island_Dream | island_Torgersen | male | target |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 32.1 | 15.5 | 188.0 | 3050.0 | 2009 | 0 | 1 | 0 | 0.0 | 0 |
| 1 | 33.1 | 16.1 | 178.0 | 2900.0 | 2008 | 0 | 1 | 0 | 0.0 | 0 |
| 2 | 33.5 | 19.0 | 190.0 | 3600.0 | 2008 | 0 | 0 | 1 | 0.0 | 0 |
| 3 | 34.0 | 17.1 | 185.0 | 3400.0 | 2008 | 0 | 1 | 0 | 0.0 | 0 |
| 4 | 34.1 | 18.1 | 193.0 | 3475.0 | 2007 | 0 | 0 | 1 | NaN | 0 |
Run the code below to convert the classes 0, 1 and 2 into the relevant species names and add this as a new column.
Try to understand how this is working - it’s a useful little pattern to know for your own datasets!
First, we are going to create a dictionary. Can you remember what we call the parts before and the parts after the colon in the dictionary?
# Define the different classes/ species
class_dict = {0 : 'Adelie',
1 : 'Chinstrap',
2 : 'Gentoo'}
class_dict
{0: 'Adelie', 1: 'Chinstrap', 2: 'Gentoo'}
Now we are going to use our dictionary for creating the column.
How do you think this is working? You may want to look up the get method of the standard python dictionary, and the apply method of pandas, just to understand a little more about this very useful way of making new conditional columns.
# Add species into the dataframe
penguins['species'] = penguins['target'].apply(lambda x: class_dict.get(x))
# view a random sample of 10 rows
penguins.sample(10)
| | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | year | island_Biscoe | island_Dream | island_Torgersen | male | target | species |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 197 | 50.7 | 19.7 | 203.0 | 4050.0 | 2009 | 0 | 1 | 0 | 1.0 | 1 | Chinstrap |
| 269 | 46.5 | 13.5 | 210.0 | 4550.0 | 2007 | 1 | 0 | 0 | 0.0 | 2 | Gentoo |
| 151 | NaN | NaN | NaN | NaN | 2007 | 0 | 0 | 1 | NaN | 0 | Adelie |
| 169 | 46.4 | 17.8 | 191.0 | 3700.0 | 2008 | 0 | 1 | 0 | 0.0 | 1 | Chinstrap |
| 12 | 35.2 | 15.9 | 186.0 | 3050.0 | 2009 | 0 | 0 | 1 | 0.0 | 0 | Adelie |
| 245 | 45.1 | 14.5 | 215.0 | 5000.0 | 2007 | 1 | 0 | 0 | 0.0 | 2 | Gentoo |
| 307 | 49.2 | 15.2 | 221.0 | 6300.0 | 2007 | 1 | 0 | 0 | 1.0 | 2 | Gentoo |
| 252 | 45.4 | 14.6 | 211.0 | 4800.0 | 2007 | 1 | 0 | 0 | 0.0 | 2 | Gentoo |
| 68 | 38.3 | 19.2 | 189.0 | 3950.0 | 2008 | 0 | 1 | 0 | 1.0 | 0 | Adelie |
| 207 | 51.5 | 18.7 | 187.0 | 3250.0 | 2009 | 0 | 1 | 0 | 1.0 | 1 | Chinstrap |
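The dictionary lookup pattern above can be tried out on a tiny frame. This is a minimal sketch, not part of the original exercise: the `toy` dataframe and the extra class value `3` are made up for illustration, and `Series.map` is shown as an alternative to `apply` with `get`.

```python
import pandas as pd

# Toy stand-in for the penguins frame; only the 'target' column matters here
toy = pd.DataFrame({"target": [0, 2, 1, 0, 3]})

class_dict = {0: "Adelie", 1: "Chinstrap", 2: "Gentoo"}

# dict.get returns None for a key that is not in the dictionary (here, 3),
# so the lookup never raises a KeyError
toy["species_apply"] = toy["target"].apply(lambda x: class_dict.get(x))

# Series.map with a dictionary does the same lookup; unknown keys become missing
toy["species_map"] = toy["target"].map(class_dict)

print(toy)
```

Both columns should agree for the known classes; the row with the unknown class simply ends up with a missing species rather than an error.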
Let’s take a look at some stats about the data to get an idea of the scale and distribution of the different features. Run the cell below to do this.
# Take a look at some stats about the data
penguins.describe()
| | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | year | island_Biscoe | island_Dream | island_Torgersen | male | target |
|---|---|---|---|---|---|---|---|---|---|---|
| count | 342.000000 | 342.000000 | 342.000000 | 342.000000 | 344.000000 | 344.000000 | 344.000000 | 344.000000 | 333.000000 | 344.000000 |
| mean | 43.921930 | 17.151170 | 200.915205 | 4201.754386 | 2008.029070 | 0.488372 | 0.360465 | 0.151163 | 0.504505 | 0.918605 |
| std | 5.459584 | 1.974793 | 14.061714 | 801.954536 | 0.818356 | 0.500593 | 0.480835 | 0.358729 | 0.500732 | 0.893320 |
| min | 32.100000 | 13.100000 | 172.000000 | 2700.000000 | 2007.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 25% | 39.225000 | 15.600000 | 190.000000 | 3550.000000 | 2007.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 50% | 44.450000 | 17.300000 | 197.000000 | 4050.000000 | 2008.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 |
| 75% | 48.500000 | 18.700000 | 213.000000 | 4750.000000 | 2009.000000 | 1.000000 | 1.000000 | 0.000000 | 1.000000 | 2.000000 |
| max | 59.600000 | 21.500000 | 231.000000 | 6300.000000 | 2009.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 2.000000 |
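Notice that the count row is below 344 for some columns: 342 for the four measurements and 333 for male, which points to 2 and 11 missing values respectively. A quick way to check this directly is sketched below on a small made-up frame (the `toy` data is hypothetical); the same `isna().sum()` call could be run on `penguins`.

```python
import numpy as np
import pandas as pd

# Small hypothetical frame with the same kind of gaps as the penguins data
toy = pd.DataFrame({
    "bill_length_mm": [39.1, np.nan, 46.5],
    "male": [1.0, np.nan, np.nan],
    "year": [2007, 2008, 2009],
})

# Number of missing values in each column
missing_per_column = toy.isna().sum()
print(missing_per_column)

# describe() counts only non-missing values, so count + missing = number of rows
counts = toy.describe().loc["count"]
print(counts + missing_per_column)
```

We don't fill these gaps in this notebook; XGBoost treats NaN as a missing value by default, which is why the model can be fitted on the data as it is.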
18.0.0.3 Plot the Data
Before we go any further, let’s plot the penguins dataset to see how bill length and bill depth relate to the species.
Fill in the gaps below to create the plot.
adelie = penguins[penguins.species=='Adelie']
chinstrap = penguins[penguins.species == "Chinstrap"]
gentoo = penguins[penguins.species=='Gentoo']
fig, ax = plt.subplots()
fig.set_size_inches(13, 7) # adjusting the length and width of plot
# labels and scatter points
ax.scatter(adelie['bill_length_mm'], adelie['bill_depth_mm'], label="Adelie", facecolor="blue")
ax.scatter(chinstrap['bill_length_mm'], chinstrap['bill_depth_mm'], label="Chinstrap", facecolor="green")
ax.scatter(gentoo['bill_length_mm'], gentoo['bill_depth_mm'], label="Gentoo", facecolor="red")
ax.set_xlabel("Bill Length (mm)")
ax.set_ylabel("Bill Depth (mm)")
ax.grid()
ax.set_title("Penguin Bill Measurements by Species")
ax.legend()
Now it’s your turn: in the space below, make a copy of the plot that looks at the other two measurements, flipper length and body mass.
adelie = penguins[penguins.species=='Adelie']
chinstrap = penguins[penguins.species == "Chinstrap"]
gentoo = penguins[penguins.species=='Gentoo']
fig, ax = plt.subplots()
fig.set_size_inches(13, 7) # adjusting the length and width of plot
# labels and scatter points
ax.scatter(adelie['flipper_length_mm'], adelie['body_mass_g'], label="Adelie", facecolor="blue")
ax.scatter(chinstrap['flipper_length_mm'], chinstrap['body_mass_g'], label="Chinstrap", facecolor="green")
ax.scatter(gentoo['flipper_length_mm'], gentoo['body_mass_g'], label="Gentoo", facecolor="red")
ax.set_xlabel("Flipper Length (mm)")
ax.set_ylabel("Body Mass (g)")
ax.grid()
ax.set_title("Penguin Flipper Length and Body Mass by Species")
ax.legend()
18.0.0.4 Performing Classification
Time to get on to the machine learning aspect.
It’s always important to spend some time understanding your dataset first, though! What we’ve done above is just the tip of the iceberg, but it’s a good start.
Now we’re going to process our dataset for machine learning as we did in the logistic regression, decision tree and boosted tree sessions - this time we’re going to be using XGBoost.
# Define features (X) and target (y)
# Dropping the target and species since we only need the measurements
X = penguins.drop(['target','species'], axis=1)
y = penguins['target']
# get class and features names
class_names = penguins.species.unique()
feature_names = X.columns
# Splitting into train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42)
# Instantiate an XGBoost model and fit it
model = XGBClassifier(random_state=42)
model.fit(X_train, y_train)
XGBClassifier(base_score=None, booster=None, callbacks=None,
              colsample_bylevel=None, colsample_bynode=None,
              colsample_bytree=None, device=None, early_stopping_rounds=None,
              enable_categorical=False, eval_metric=None, feature_types=None,
              gamma=None, grow_policy=None, importance_type=None,
              interaction_constraints=None, learning_rate=None, max_bin=None,
              max_cat_threshold=None, max_cat_to_onehot=None,
              max_delta_step=None, max_depth=None, max_leaves=None,
              min_child_weight=None, missing=nan, monotone_constraints=None,
              multi_strategy=None, n_estimators=None, n_jobs=None,
              num_parallel_tree=None, objective='multi:softprob', ...)
18.0.1 Feature Importance
18.0.1.0.1 Look at feature importance using the feature_importances_ attribute
# Looking at standard feature importance
# For XGBoost this defaults to gain-based importance, the boosted-tree
# analogue of the mean decrease in impurity (MDI) for each feature
importances = model.feature_importances_
importances
array([0.10309331, 0.01894252, 0.5356301 , 0.00305671, 0.00406754,
       0.07005206, 0.1871977 , 0.07796   , 0.        ], dtype=float32)
18.0.1.1 Mean decrease in impurity
Generate a plot of the MDI feature importances.
feature_names = X.columns.tolist()
model_importances_mdi_series = pd.Series(importances, index=feature_names)
fig, ax = plt.subplots(figsize=(15,10))
model_importances_mdi_series.plot.bar(ax=ax)
ax.set_title("Feature importances using MDI")
ax.set_ylabel("Mean decrease in impurity")
fig.tight_layout()
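A sorted table can be easier to read than the bar chart. The sketch below hard-codes the importances printed earlier and the feature names in the column order of `X`, so that it runs on its own; in the notebook you would sort `model_importances_mdi_series` directly.

```python
import pandas as pd

# Feature names in the column order of X, and the importances printed earlier
feature_names = ["bill_length_mm", "bill_depth_mm", "flipper_length_mm",
                 "body_mass_g", "year", "island_Biscoe", "island_Dream",
                 "island_Torgersen", "male"]
importances = [0.10309331, 0.01894252, 0.5356301, 0.00305671, 0.00406754,
               0.07005206, 0.1871977, 0.07796, 0.0]

# Pair each importance with its feature name and rank from most to least important
ranking = pd.Series(importances, index=feature_names).sort_values(ascending=False)
print(ranking)

# The importances are normalised, so they should add up to (roughly) 1
print("Sum of importances:", round(ranking.sum(), 4))
```

For this model, flipper length comes out on top by a wide margin, followed by whether the penguin was found on Dream island.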
18.0.2 Permutation Feature Importance
Calculate permutation feature importance for this dataset and plot it.
This will include error bars.
# Permutation feature importance
result = permutation_importance(
model, X_test, y_test, n_repeats=10, random_state=42)
model_importances_pfi_series = pd.Series(result.importances_mean, index=feature_names)
fig, ax = plt.subplots(figsize=(15,10))
model_importances_pfi_series.plot.bar(yerr=result.importances_std, ax=ax)
ax.set_title("Feature importances using permutation on full model")
ax.set_ylabel("Mean accuracy decrease")
fig.tight_layout()
plt.show()
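To see what permutation importance is doing under the hood, here is a minimal sketch written with NumPy only. Everything in it is an assumption made for illustration: the data is synthetic, and `predict` is a hypothetical stand-in for a fitted model rather than our XGBoost classifier.

```python
import numpy as np

rng = np.random.default_rng(42)

# Synthetic data: the label depends on column 0 only; column 1 is pure noise
X = rng.normal(size=(500, 2))
y = (X[:, 0] > 0).astype(int)

def predict(X):
    """Hypothetical stand-in for a fitted model's predict method."""
    return (X[:, 0] > 0).astype(int)

# Accuracy on the untouched data
baseline = np.mean(predict(X) == y)

def permutation_drop(col, n_repeats=10):
    """Mean and std of the accuracy drop when one column is shuffled."""
    drops = []
    for _ in range(n_repeats):
        X_perm = X.copy()
        # Shuffling breaks the link between this column and the label
        X_perm[:, col] = rng.permutation(X_perm[:, col])
        drops.append(baseline - np.mean(predict(X_perm) == y))
    return np.mean(drops), np.std(drops)

for col in range(X.shape[1]):
    mean_drop, std_drop = permutation_drop(col)
    print(f"column {col}: mean accuracy drop = {mean_drop:.3f} (std {std_drop:.3f})")
```

Shuffling the column the toy model relies on should cost a lot of accuracy, while shuffling the noise column costs nothing. The standard deviation across repeats is what the error bars in the plot above represent.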
18.0.2.1 Predictions
Use the model to make predictions for the training and test set.
# Training and test predictions
y_pred_train = model.predict(X_train)
y_pred_test = model.predict(X_test)
18.0.2.2 Assessing Performance
Run this code to generate metrics for the training and test performance of this model.
accuracy_train = np.mean(y_pred_train == y_train)
accuracy_test = np.mean(y_pred_test == y_test)
print(f'Accuracy of predicting training data = {accuracy_train:.2%}')
print(f'Accuracy of predicting test data = {accuracy_test:.2%}')
Accuracy of predicting training data = 100.00%
Accuracy of predicting test data = 97.10%
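A test accuracy of 97.10% is consistent with 67 of the 69 test rows being classified correctly, but accuracy alone doesn't tell us which species get confused with each other. A confusion matrix does. The sketch below uses made-up labels and predictions (not the notebook's actual results) to show the idea; in the notebook the same call could be made with `y_test` and `y_pred_test`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical labels and predictions (not the notebook's actual results)
y_true = np.array([0, 0, 1, 1, 2, 2, 2, 0])
y_pred = np.array([0, 0, 1, 2, 2, 2, 2, 1])

# Accuracy calculated the same way as in the cell above
accuracy = np.mean(y_pred == y_true)

# Rows are true classes, columns are predicted classes
cm = confusion_matrix(y_true, y_pred)

print(f"Accuracy = {accuracy:.2%}")
print(cm)
```

The diagonal of the matrix holds the correct predictions, so the diagonal total divided by the overall total gives the accuracy again.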
18.1 PDP Plots
Now let’s create a partial dependence plot for flipper length.
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
model, # Your fitted model
X, # Your feature matrix
features=['flipper_length_mm'], # List of features to plot
target=0,
kind='average', # Type of PDP
ax=ax,
random_state=42
)
plt.show()
Now create two plots side-by-side for bill length and bill depth.
HINT: You don’t need to create multiple separate plots using matplotlib for this - you can do it from within the graphing function we’re using from scikit-learn.
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
model, # Your fitted model
X, # Your feature matrix
features=['bill_length_mm', 'bill_depth_mm'], # List of features to plot
target=0,
kind='average', # Type of PDP
ax=ax,
random_state=42
)
plt.show()
18.2 ICE Plots
Now create three ICE plots of the same feature - one for each class. Make sure to give each plot a name.
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
model, # Your fitted model
X, # Your feature matrix
features=['bill_length_mm'], # List of features to plot
target=0,
kind='individual', # Type of PDP
ax=ax,
random_state=42
)
plt.title("Adelie Penguins - Bill Length ICE Plot")
plt.show()
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
model, # Your fitted model
X, # Your feature matrix
features=['bill_length_mm'], # List of features to plot
target=1,
kind='individual', # Type of PDP
ax=ax,
random_state=42
)
plt.title("Chinstrap Penguins - Bill Length ICE Plot")
plt.show()
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
model, # Your fitted model
X, # Your feature matrix
features=['bill_length_mm'], # List of features to plot
target=2,
kind='individual', # Type of PDP
ax=ax,
random_state=42
)
plt.title("Gentoo Penguins - Bill Length ICE Plot")
plt.show()
Now, just for one of the classes, create an ICE plot for bill_length_mm that also shows the average of all the ICE plots - a joint PDP/ICE plot, effectively!
Again, make sure you provide a title.
fig, ax = plt.subplots(figsize=(10, 6))
display = PartialDependenceDisplay.from_estimator(
model, # Your fitted model
X, # Your feature matrix
features=['bill_length_mm'], # List of features to plot
target=2,
kind='both', # Type of PDP
ax=ax,
random_state=42
)
plt.title("Gentoo Penguins - Joint ICE/PDP Plot")
plt.show()
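The link between ICE curves and the PDP can be seen by building both by hand. This is a rough sketch under stated assumptions: the data is random, and `predict_score` is a hypothetical toy function standing in for a model's predicted probability. It is not how scikit-learn implements the plots internally.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data with two features
X = rng.uniform(0, 1, size=(50, 2))

def predict_score(X):
    """Toy model: a logistic curve in feature 0 whose offset depends on feature 1."""
    return 1 / (1 + np.exp(-(6 * X[:, 0] - 3 + X[:, 1])))

grid = np.linspace(0, 1, 11)              # values to try for feature 0
ice = np.empty((X.shape[0], grid.size))   # one row per instance, one column per grid value

for j, value in enumerate(grid):
    X_mod = X.copy()
    X_mod[:, 0] = value                   # force every row to the same feature value
    ice[:, j] = predict_score(X_mod)      # one prediction per row: the ICE curves

pdp = ice.mean(axis=0)                    # the PDP is the average of the ICE curves

print("ICE array shape:", ice.shape)
print("PDP values:", np.round(pdp, 3))
```

Each row of `ice` is one of the thin lines in an ICE plot, and `pdp` is the single average line that `kind='both'` draws on top of them.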
18.2.0.1 BONUS: 2D PDP Plots
Now create a 2D plot of bill length and bill depth.
PartialDependenceDisplay.from_estimator(
model,
X_test,
features=[('bill_length_mm', 'bill_depth_mm')],
kind='average',
target=0,
random_state=0
)
18.3 SHAP
We have a multiclass problem with our penguins dataset.
This results in some slightly different outputs from our SHAP code, which can be confusing to deal with, so for now we’re just going to focus on a binary classification problem - is a penguin an Adelie, or not?
Run the code below to turn this into a binary classification problem and retrain the model.
penguins_binary = penguins.copy()
# If Adelie penguin, return 1, else return 0
penguins_binary['target'] = np.where(penguins_binary['target'] == 0, 1, 0)
penguins_binary['species'] = np.where(penguins_binary['species'] == "Adelie", "Adelie", "Not Adelie")
# Define features (X) and target (y)
# Dropping the target and species since we only need the measurements
X = penguins_binary.drop(['target','species'], axis=1)
y = penguins_binary['target']
# get class and features names
class_names = penguins_binary.species.unique()
feature_names = X.columns
# Splitting into train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42)
# Instantiate an XGBoost model and fit it
model = XGBClassifier(random_state=42)
model.fit(X_train, y_train)
XGBClassifier(base_score=None, booster=None, callbacks=None,
              colsample_bylevel=None, colsample_bynode=None,
              colsample_bytree=None, device=None, early_stopping_rounds=None,
              enable_categorical=False, eval_metric=None, feature_types=None,
              gamma=None, grow_policy=None, importance_type=None,
              interaction_constraints=None, learning_rate=None, max_bin=None,
              max_cat_threshold=None, max_cat_to_onehot=None,
              max_delta_step=None, max_depth=None, max_leaves=None,
              min_child_weight=None, missing=nan, monotone_constraints=None,
              multi_strategy=None, n_estimators=None, n_jobs=None,
              num_parallel_tree=None, random_state=42, ...)
18.3.1 Obtaining SHAP Values
Generate a SHAP explainer for our model, using X_train as the background data.
# Create the SHAP explainer
explainer = shap.Explainer(
    model,
    X_train
)
Now create the shap_values object, using X_test as the foreground data.
shap_values = explainer.shap_values(X_test)
shap_values
array([[-5.17051824e+00, 1.21827941e+00, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, -1.36875771e-01],
[ 1.13075847e-01, 1.78565094e+00, -4.36455202e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.35543021e-01, -1.91024542e-01],
[-2.72566813e-02, -3.30757442e+00, -2.22220667e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[-5.10160202e+00, 1.83535931e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[-4.14983763e+00, -1.87998490e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 5.45284030e+00, 1.74075824e+00, -3.02373941e-01,
1.54761915e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[ 1.91534915e+00, -3.58559378e+00, -3.02373941e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -2.18033250e-01, -1.91024542e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 4.68026578e+00, 1.44937979e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.42733446e+00, -1.30585920e-01],
[ 6.04219020e+00, 1.30955449e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.35887119e-02, 1.08201245e-01],
[ 4.79211754e+00, -1.02131687e-03, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.54396007e+00, -1.30585920e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 6.44204529e+00, -5.33143660e-01, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.01846663e-01, -1.30585920e-01],
[ 5.71326724e+00, 1.66361659e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -8.35388462e-02, 1.08201245e-01],
[-4.44524486e-02, -3.30757442e+00, -2.22220667e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[-4.04089575e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 4.97731472e+00, 1.03570520e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.54396007e+00, -1.30585920e-01],
[ 3.33277721e+00, 2.57971638e+00, 1.75061948e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.46053878e-01, 1.59685997e-01],
[ 5.91554046e+00, 1.47107445e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.02227896e-02, -1.30585920e-01],
[-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.58941987e+00, 8.37323978e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -9.51627192e-02, -1.30585920e-01],
[ 2.60384000e+00, -4.47425562e-01, 9.22989947e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.50891294e-01, -1.91024542e-01],
[-5.13327016e+00, -1.40029112e-01, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, 1.13559266e-01],
[ 5.68476218e+00, 1.66361659e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -9.02227896e-02, 1.08201245e-01],
[-5.16359000e+00, 9.08194313e-01, 7.00742684e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.44841380e-01, 1.13559266e-01],
[-3.45011731e+00, 2.42737406e+00, 7.00742684e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.18269527e-01, 1.59685997e-01],
[-3.79519642e+00, -2.13501758e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.71910748e-01, -1.91024542e-01],
[ 6.07069525e+00, 1.30955449e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.69047685e-02, 1.08201245e-01],
[ 4.36485172e+00, 1.75760771e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[-5.06125901e+00, 4.69742856e-02, 4.39282524e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.85268838e-01, -1.36875771e-01],
[ 2.94871511e+00, 1.16633540e+00, 1.27291407e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 2.49448815e+00, -1.91024542e-01],
[ 5.94404552e+00, 1.47107445e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, -1.30585920e-01],
[ 3.45756491e+00, 6.67898319e-01, 1.75061948e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.16614424e-01, 1.59685997e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.94008613e+00, 1.97020284e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 4.76751433e+00, 1.38394169e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.43894375e+00, -1.30585920e-01],
[-5.19412194e+00, 9.21530487e-01, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.44841380e-01, 1.13559266e-01],
[-4.41536294e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[-4.81358717e+00, -8.71694393e-01, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.83302972e-01, 1.13559266e-01],
[-4.81358717e+00, -1.02284420e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.83302972e-01, 1.13559266e-01],
[ 4.60306368e+00, 1.51015618e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.43894375e+00, 1.08201245e-01],
[ 6.44204529e+00, -3.81993851e-01, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -1.01846663e-01, -1.30585920e-01],
[ 4.78287797e+00, -2.68434791e-01, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, -1.30585920e-01],
[-7.61286116e-01, -3.12572213e+00, -1.47042132e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[ 4.49503812e+00, 1.58729784e+00, 1.60711127e-01,
1.54761915e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[ 4.49106620e+00, 1.63139322e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[-5.11681622e+00, 1.45285302e+00, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 4.96807515e+00, 1.03570520e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, -1.30585920e-01],
[-3.27354855e+00, -2.53607305e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.71910748e-01, -1.91024542e-01],
[-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.08042636e+00, 1.30955449e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.35887119e-02, -1.30585920e-01],
[ 4.93718394e+00, -3.09827117e-01, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, 1.08201245e-01],
[-5.17991076e+00, 1.21827941e+00, -5.74065812e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 3.65377114e+00, 2.45184202e+00, 9.22989947e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.42782735e-01, 1.59685997e-01],
[-5.07603471e+00, 3.36381122e-02, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
[-4.53595532e+00, -1.58787129e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.84721798e-01, -1.36875771e-01],
[-3.45512069e+00, -2.34506634e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.81345309e-01, -1.91024542e-01],
[-5.12608560e+00, 9.08194313e-01, 4.39282524e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.46807246e-01, -1.36875771e-01],
[ 5.71326724e+00, 1.66361659e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.58941987e+00, 8.37323978e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.51627192e-02, -1.30585920e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-5.07603471e+00, 3.36381122e-02, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
[ 4.57339775e+00, 1.58729784e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, -1.30585920e-01],
[ 5.53214335e+00, 1.84474047e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[ 6.66092991e+00, 7.65813945e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.51627192e-02, -1.30585920e-01]])
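Each row of the array above holds one SHAP value per feature for one test penguin, and these values add up: base value plus the row sum gives the model's raw output for that penguin. The sketch below works through the first row. It assumes the explainer is working in log-odds units, which is the usual behaviour of the tree explainer for an XGBoost classifier but is not confirmed here, and the base value is a placeholder because the real one (`explainer.expected_value`) isn't printed in this notebook.

```python
import numpy as np

# First row of the SHAP array printed above (one value per feature)
row = np.array([-5.17051824e+00, 1.21827941e+00, -5.74065812e-02,
                -5.04409203e-02, 0.0, 0.0,
                -3.68462939e-01, -3.35151969e-01, -1.36875771e-01])

# ASSUMPTION: placeholder base value; the real one is explainer.expected_value
base_value = 0.0

# SHAP values are additive: base value + contributions = raw model output
log_odds = base_value + row.sum()

# The sigmoid turns log-odds into a probability
probability = 1 / (1 + np.exp(-log_odds))

print(f"Sum of SHAP values: {row.sum():.4f}")
print(f"Implied probability with the placeholder base value: {probability:.4f}")
```

The large negative first value (bill length) dominates this row, so the features together push this penguin strongly away from the positive class, Adelie.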
It looks like it’s returned our outputs just as an array instead of a SHAP explanation object. Run the code below to turn our object into a proper shap.Explanation() object, as this is what all the plotting functions will be expecting.
# Create an Explanation object
shap_values = shap.Explanation(
    values=shap_values,
    base_values=explainer.expected_value,
    data=X_test.values,
    feature_names=X.columns
)
Now let’s see what this looks like instead.
shap_values
.values =
array([[-5.17051824e+00, 1.21827941e+00, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, -1.36875771e-01],
[ 1.13075847e-01, 1.78565094e+00, -4.36455202e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.35543021e-01, -1.91024542e-01],
[-2.72566813e-02, -3.30757442e+00, -2.22220667e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[-5.10160202e+00, 1.83535931e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[-4.14983763e+00, -1.87998490e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 5.45284030e+00, 1.74075824e+00, -3.02373941e-01,
1.54761915e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[ 1.91534915e+00, -3.58559378e+00, -3.02373941e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -2.18033250e-01, -1.91024542e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 4.68026578e+00, 1.44937979e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.42733446e+00, -1.30585920e-01],
[ 6.04219020e+00, 1.30955449e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.35887119e-02, 1.08201245e-01],
[ 4.79211754e+00, -1.02131687e-03, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.54396007e+00, -1.30585920e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 6.44204529e+00, -5.33143660e-01, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.01846663e-01, -1.30585920e-01],
[ 5.71326724e+00, 1.66361659e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -8.35388462e-02, 1.08201245e-01],
[-4.44524486e-02, -3.30757442e+00, -2.22220667e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[-4.04089575e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 4.97731472e+00, 1.03570520e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.54396007e+00, -1.30585920e-01],
[ 3.33277721e+00, 2.57971638e+00, 1.75061948e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.46053878e-01, 1.59685997e-01],
[ 5.91554046e+00, 1.47107445e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.02227896e-02, -1.30585920e-01],
[-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.58941987e+00, 8.37323978e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -9.51627192e-02, -1.30585920e-01],
[ 2.60384000e+00, -4.47425562e-01, 9.22989947e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.50891294e-01, -1.91024542e-01],
[-5.13327016e+00, -1.40029112e-01, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, 1.13559266e-01],
[ 5.68476218e+00, 1.66361659e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -9.02227896e-02, 1.08201245e-01],
[-5.16359000e+00, 9.08194313e-01, 7.00742684e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.44841380e-01, 1.13559266e-01],
[-3.45011731e+00, 2.42737406e+00, 7.00742684e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.18269527e-01, 1.59685997e-01],
[-3.79519642e+00, -2.13501758e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.71910748e-01, -1.91024542e-01],
[ 6.07069525e+00, 1.30955449e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.69047685e-02, 1.08201245e-01],
[ 4.36485172e+00, 1.75760771e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[-5.06125901e+00, 4.69742856e-02, 4.39282524e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.85268838e-01, -1.36875771e-01],
[ 2.94871511e+00, 1.16633540e+00, 1.27291407e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 2.49448815e+00, -1.91024542e-01],
[ 5.94404552e+00, 1.47107445e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, -1.30585920e-01],
[ 3.45756491e+00, 6.67898319e-01, 1.75061948e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.16614424e-01, 1.59685997e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.94008613e+00, 1.97020284e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 4.76751433e+00, 1.38394169e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.43894375e+00, -1.30585920e-01],
[-5.19412194e+00, 9.21530487e-01, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.44841380e-01, 1.13559266e-01],
[-4.41536294e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[-4.81358717e+00, -8.71694393e-01, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.83302972e-01, 1.13559266e-01],
[-4.81358717e+00, -1.02284420e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.83302972e-01, 1.13559266e-01],
[ 4.60306368e+00, 1.51015618e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.43894375e+00, 1.08201245e-01],
[ 6.44204529e+00, -3.81993851e-01, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -1.01846663e-01, -1.30585920e-01],
[ 4.78287797e+00, -2.68434791e-01, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, -1.30585920e-01],
[-7.61286116e-01, -3.12572213e+00, -1.47042132e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[ 4.49503812e+00, 1.58729784e+00, 1.60711127e-01,
1.54761915e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[ 4.49106620e+00, 1.63139322e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[-5.11681622e+00, 1.45285302e+00, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 4.96807515e+00, 1.03570520e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, -1.30585920e-01],
[-3.27354855e+00, -2.53607305e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.71910748e-01, -1.91024542e-01],
[-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.08042636e+00, 1.30955449e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.35887119e-02, -1.30585920e-01],
[ 4.93718394e+00, -3.09827117e-01, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, 1.08201245e-01],
[-5.17991076e+00, 1.21827941e+00, -5.74065812e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 3.65377114e+00, 2.45184202e+00, 9.22989947e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.42782735e-01, 1.59685997e-01],
[-5.07603471e+00, 3.36381122e-02, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
[-4.53595532e+00, -1.58787129e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.84721798e-01, -1.36875771e-01],
[-3.45512069e+00, -2.34506634e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.81345309e-01, -1.91024542e-01],
[-5.12608560e+00, 9.08194313e-01, 4.39282524e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.46807246e-01, -1.36875771e-01],
[ 5.71326724e+00, 1.66361659e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.58941987e+00, 8.37323978e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.51627192e-02, -1.30585920e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-5.07603471e+00, 3.36381122e-02, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
[ 4.57339775e+00, 1.58729784e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, -1.30585920e-01],
[ 5.53214335e+00, 1.84474047e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[ 6.66092991e+00, 7.65813945e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.51627192e-02, -1.30585920e-01]])
.base_values =
-0.7156927563131259
.data =
array([[ 32.1, 15.5, 188. , ..., 1. , 0. , 0. ],
[ 33.1, 16.1, 178. , ..., 1. , 0. , 0. ],
[ 33.5, 19. , 190. , ..., 0. , 1. , 0. ],
...,
[ 55.9, 17. , 228. , ..., 0. , 0. , 1. ],
[ 59.6, 17. , 230. , ..., 0. , 0. , 1. ],
[ nan, nan, nan, ..., 0. , 0. , nan]])
Finally, let’s grab just the numeric component (our actual shap values).
shap_values_numeric = shap_values.values
shap_values_numeric
array([[-5.17051824e+00, 1.21827941e+00, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, -1.36875771e-01],
[ 1.13075847e-01, 1.78565094e+00, -4.36455202e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.35543021e-01, -1.91024542e-01],
[-2.72566813e-02, -3.30757442e+00, -2.22220667e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[-5.10160202e+00, 1.83535931e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[-4.14983763e+00, -1.87998490e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 5.45284030e+00, 1.74075824e+00, -3.02373941e-01,
1.54761915e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[ 1.91534915e+00, -3.58559378e+00, -3.02373941e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -2.18033250e-01, -1.91024542e-01],
[-4.44195123e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 4.68026578e+00, 1.44937979e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.42733446e+00, -1.30585920e-01],
[ 6.04219020e+00, 1.30955449e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.35887119e-02, 1.08201245e-01],
[ 4.79211754e+00, -1.02131687e-03, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.54396007e+00, -1.30585920e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 6.44204529e+00, -5.33143660e-01, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.01846663e-01, -1.30585920e-01],
[ 5.71326724e+00, 1.66361659e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -8.35388462e-02, 1.08201245e-01],
[-4.44524486e-02, -3.30757442e+00, -2.22220667e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[-4.04089575e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[ 4.97731472e+00, 1.03570520e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.54396007e+00, -1.30585920e-01],
[ 3.33277721e+00, 2.57971638e+00, 1.75061948e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.46053878e-01, 1.59685997e-01],
[ 5.91554046e+00, 1.47107445e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.02227896e-02, -1.30585920e-01],
[-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.58941987e+00, 8.37323978e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -9.51627192e-02, -1.30585920e-01],
[ 2.60384000e+00, -4.47425562e-01, 9.22989947e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.50891294e-01, -1.91024542e-01],
[-5.13327016e+00, -1.40029112e-01, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, 1.13559266e-01],
[ 5.68476218e+00, 1.66361659e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -9.02227896e-02, 1.08201245e-01],
[-5.16359000e+00, 9.08194313e-01, 7.00742684e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.44841380e-01, 1.13559266e-01],
[-3.45011731e+00, 2.42737406e+00, 7.00742684e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.18269527e-01, 1.59685997e-01],
[-3.79519642e+00, -2.13501758e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.71910748e-01, -1.91024542e-01],
[ 6.07069525e+00, 1.30955449e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.69047685e-02, 1.08201245e-01],
[ 4.36485172e+00, 1.75760771e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[-5.06125901e+00, 4.69742856e-02, 4.39282524e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.85268838e-01, -1.36875771e-01],
[ 2.94871511e+00, 1.16633540e+00, 1.27291407e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 2.49448815e+00, -1.91024542e-01],
[ 5.94404552e+00, 1.47107445e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, -1.30585920e-01],
[ 3.45756491e+00, 6.67898319e-01, 1.75061948e-01,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -2.16614424e-01, 1.59685997e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-4.94008613e+00, 1.97020284e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 4.76751433e+00, 1.38394169e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.43894375e+00, -1.30585920e-01],
[-5.19412194e+00, 9.21530487e-01, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.44841380e-01, 1.13559266e-01],
[-4.41536294e+00, -1.58787129e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[-4.81358717e+00, -8.71694393e-01, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.83302972e-01, 1.13559266e-01],
[-4.81358717e+00, -1.02284420e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.83302972e-01, 1.13559266e-01],
[ 4.60306368e+00, 1.51015618e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.43894375e+00, 1.08201245e-01],
[ 6.44204529e+00, -3.81993851e-01, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -1.01846663e-01, -1.30585920e-01],
[ 4.78287797e+00, -2.68434791e-01, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, -1.30585920e-01],
[-7.61286116e-01, -3.12572213e+00, -1.47042132e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.31137460e-01, -1.91024542e-01],
[ 4.49503812e+00, 1.58729784e+00, 1.60711127e-01,
1.54761915e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[ 4.49106620e+00, 1.63139322e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, 1.08201245e-01],
[-5.11681622e+00, 1.45285302e+00, -5.74065812e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 4.96807515e+00, 1.03570520e+00, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, -1.30585920e-01],
[-3.27354855e+00, -2.53607305e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.71910748e-01, -1.91024542e-01],
[-4.01430746e+00, -1.98892677e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.08042636e+00, 1.30955449e+00, 9.22989947e-02,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.35887119e-02, -1.30585920e-01],
[ 4.93718394e+00, -3.09827117e-01, 1.27291407e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.58661936e+00, 1.08201245e-01],
[-5.17991076e+00, 1.21827941e+00, -5.74065812e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.35151969e-01, 1.13559266e-01],
[ 3.65377114e+00, 2.45184202e+00, 9.22989947e-02,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -1.42782735e-01, 1.59685997e-01],
[-5.07603471e+00, 3.36381122e-02, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
[-4.53595532e+00, -1.58787129e+00, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.84721798e-01, -1.36875771e-01],
[-3.45512069e+00, -2.34506634e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.81345309e-01, -1.91024542e-01],
[-5.12608560e+00, 9.08194313e-01, 4.39282524e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.46807246e-01, -1.36875771e-01],
[ 5.71326724e+00, 1.66361659e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[-4.10589588e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, -1.36875771e-01],
[ 6.58941987e+00, 8.37323978e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.51627192e-02, -1.30585920e-01],
[-4.13248416e+00, -1.89733836e+00, -1.60803193e-01,
4.70017666e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, -3.84721798e-01, 1.13559266e-01],
[-5.07603471e+00, 3.36381122e-02, 7.00742684e-02,
-5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -3.83302972e-01, -1.36875771e-01],
[ 4.57339775e+00, 1.58729784e+00, 1.60711127e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
2.77963270e-01, 1.39628446e+00, -1.30585920e-01],
[ 5.53214335e+00, 1.84474047e+00, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -8.35388462e-02, 1.08201245e-01],
[ 6.66092991e+00, 7.65813945e-01, 1.75061948e-01,
-2.46472679e-02, 0.00000000e+00, 0.00000000e+00,
-3.68462939e-01, -9.51627192e-02, -1.30585920e-01]])
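The base value and the per-feature SHAP values are designed to add up to the model's raw output for a given row. As a minimal sketch, the snippet below takes the first row of `shap_values_numeric` printed above, together with the `.base_values` entry, and sums them. It assumes the explainer output is on the log-odds scale for a single class (common for tree explainers on XGBoost classifiers, but not confirmed here), so the conversion to a probability is illustrative only.

```python
import numpy as np

# First row of shap_values_numeric, copied from the output above
first_row = np.array([
    -5.17051824e+00, 1.21827941e+00, -5.74065812e-02,
    -5.04409203e-02, 0.00000000e+00, 0.00000000e+00,
    -3.68462939e-01, -3.35151969e-01, -1.36875771e-01,
])
# Value shown under .base_values above
base_value = -0.7156927563131259

# Additivity: base value + sum of contributions = raw model output for the row
raw_output = base_value + first_row.sum()

# ASSUMPTION: the raw output is in log-odds, so a sigmoid gives a probability
probability = 1.0 / (1.0 + np.exp(-raw_output))

print(f"raw output: {raw_output:.4f}")
print(f"probability (if log-odds): {probability:.4f}")
```

Almost all of the total comes from the first feature's contribution, which is the kind of pattern the waterfall and force plots later in this section make visible.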
18.3.2 Exploring the SHAP outputs
First, let’s list our most important features according to SHAP and compare them with the MDI ranking.
# get feature importance for comparison using MDI method
features = list(X_train)
feature_importances = model.feature_importances_
importances = pd.DataFrame(index=features)
importances['importance'] = feature_importances
importances['rank'] = importances['importance'].rank(ascending=False).values
importances.sort_values('rank').head()
# Get Shapley importances
# Calculate the mean Shapley value for each feature in the training set
importances['mean_shapley_values'] = np.mean(
    shap_values_numeric, axis=0
)
# Calculate the mean absolute Shapley value for each feature in the training set
# This gives us the average importance of each feature
importances['mean_abs_shapley_values'] = np.mean(
    np.abs(shap_values_numeric), axis=0
)
# Feature names ordered by MDI importance (most important first)
importance_top = importances.sort_values(
    by='importance', ascending=False
).index
# Feature names ordered by mean absolute Shapley value
shapley_top = importances.sort_values(
    by='mean_abs_shapley_values', ascending=False
).index
# Add to DataFrame
top_features = pd.DataFrame()
top_features['importances'] = importance_top.values
top_features['Shapley'] = shapley_top.values
# Display
top_features
| | importances | Shapley |
|---|---|---|
| 0 | bill_length_mm | bill_length_mm |
| 1 | island_Torgersen | bill_depth_mm |
| 2 | bill_depth_mm | island_Torgersen |
| 3 | flipper_length_mm | island_Dream |
| 4 | male | flipper_length_mm |
| 5 | island_Dream | male |
| 6 | body_mass_g | body_mass_g |
| 7 | year | year |
| 8 | island_Biscoe | island_Biscoe |
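The SHAP ranking above is built from the mean absolute Shapley value rather than the plain mean. A small sketch of why, using made-up numbers and hypothetical feature names (`feature_a`, `feature_b`): a feature that pushes predictions strongly in both directions can have a mean close to zero even though it matters a great deal.

```python
import numpy as np

# Toy SHAP matrix: 4 rows, 2 features (values invented for illustration)
toy_shap = np.array([
    [ 2.0, 0.1],
    [-2.0, 0.2],
    [ 2.0, 0.1],
    [-2.0, 0.2],
])
toy_features = ["feature_a", "feature_b"]  # hypothetical names

# Plain mean: positive and negative contributions cancel out
mean_shap = toy_shap.mean(axis=0)
# Mean absolute value: measures the typical size of the contribution
mean_abs_shap = np.abs(toy_shap).mean(axis=0)

# Rank features from largest to smallest mean absolute contribution
order = np.argsort(-mean_abs_shap)
ranking = [toy_features[i] for i in order]

print("mean:", mean_shap)
print("mean abs:", mean_abs_shap)
print("ranking:", ranking)
```

Here `feature_a` has a mean of zero but the largest mean absolute value, so it is ranked first.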
18.3.2.1 SHAP plots
Generate a bar plot of the SHAP values.
shap.plots.bar(shap_values)
Generate a beeswarm plot.
shap.plots.beeswarm(shap_values)
DimensionError: Feature and SHAP matrices must have the same number of rows!
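The error message points at a mismatch between the number of rows of SHAP values and the number of rows of feature data. The sketch below is one way to check for that. `check_shap_alignment` is a hypothetical helper (not part of the SHAP library) and the arrays are toy data; whether a trailing row of missing values, like the one visible in the `.data` output earlier, is the actual cause here is an assumption rather than something the notebook confirms.

```python
import numpy as np

def check_shap_alignment(values, data):
    """Hypothetical helper: compare row counts and flag feature rows containing NaN."""
    values = np.asarray(values, dtype=float)
    data = np.asarray(data, dtype=float)
    # Indices of feature rows where at least one entry is missing
    nan_rows = np.flatnonzero(np.isnan(data).any(axis=1))
    return {
        "n_value_rows": values.shape[0],
        "n_data_rows": data.shape[0],
        "rows_match": values.shape[0] == data.shape[0],
        "nan_rows": nan_rows.tolist(),
    }

# Toy example: 3 rows of SHAP values but 4 rows of feature data,
# the last of which contains only missing values
toy_values = np.zeros((3, 2))
toy_data = np.array([
    [32.1, 15.5],
    [33.1, 16.1],
    [55.9, 17.0],
    [np.nan, np.nan],
])

report = check_shap_alignment(toy_values, toy_data)
print(report)
```

On the notebook's own objects, the equivalent call would be `check_shap_alignment(shap_values.values, shap_values.data)`.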

Generate a waterfall plot for 5 different examples from the dataset.
shap.plots.waterfall(shap_values[0])
shap.plots.waterfall(shap_values[3])
shap.plots.waterfall(shap_values[5])
shap.plots.waterfall(shap_values[-1])
shap.plots.waterfall(shap_values[194])
18.3.2.2 Dependence Plots for each Class (Species)
Let’s look at the columns in our dataset and the indices.
# Let's see the features and their respective index numbers
for idx, col_name in enumerate(X_test.columns):
    print(f"{idx} - {col_name}")
First, generate a scatter plot for the bill length.
shap.plots.scatter(shap_values[:, 'bill_length_mm'])
Now colour this by bill depth.
shap.plots.scatter(shap_values[:, 'bill_length_mm'], color=shap_values[:, 'bill_depth_mm'])
Now colour it by the most strongly interacting feature.
# create a dependence scatter plot to show the effect of a single feature across the whole dataset
shap.plots.scatter(shap_values[:, 'bill_length_mm'], color=shap_values)
Now let’s iterate through and create scatter plots per column.
# dependence plots
fig, ax = plt.subplots(3, 3, figsize=(20,10))
ax = ax.ravel()
for idx, col_name in enumerate(feature_names):
    shap.plots.scatter(shap_values[:, col_name], show=False, ax=ax[idx])
18.3.2.3 Force Plots
Create a force plot for the whole dataset.
shap.plots.force(shap_values)
Create a force plot for five randomly chosen pieces of data.
shap.plots.force(shap_values[0])
shap.plots.force(shap_values[1])
shap.plots.force(shap_values[-1])
shap.plots.force(shap_values[185])
shap.plots.force(shap_values[247])
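The row indices used for the individual force plots are hard-coded. If a reproducible random selection is preferred, one option is to draw the indices with NumPy, as sketched below. The seed and the row count `n_rows` are placeholders chosen for illustration; in the notebook the row count would come from `shap_values.values.shape[0]`.

```python
import numpy as np

# ASSUMPTION: placeholder row count; replace with shap_values.values.shape[0]
n_rows = 258

# A seeded generator makes the "random" selection repeatable between runs
rng = np.random.default_rng(42)

# Draw 5 distinct row indices, then sort them for readability
random_indices = sorted(int(i) for i in rng.choice(n_rows, size=5, replace=False))
print(random_indices)

# Each index could then be passed to the force plot, for example:
# for i in random_indices:
#     shap.plots.force(shap_values[i])
```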