SHAP (SHapley Additive exPlanations, see Lundberg and Lee (2017)) is an ingenious way to study black box models. SHAP values decompose predictions, as fairly as possible, into additive feature contributions. Crunching SHAP values requires clever algorithms. Analyzing them, however, is super easy with the right visualizations. {shapviz} offers the latter.
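To make "additive" concrete, here is a minimal base R sketch that does not use {shapviz} at all. It assumes a linear regression (lm() on mtcars, chosen only for illustration) and interventional SHAP with the training data as background. In that special case, the SHAP value of feature j has the closed form beta_j * (x_j - mean(x_j)), and the baseline plus the row sums of the SHAP matrix reproduce every prediction.

```r
# Illustrative sketch, not {shapviz} code: closed-form SHAP values of a
# linear model (interventional SHAP, training data as background).
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)
X <- as.matrix(mtcars[c("wt", "hp", "qsec")])
beta <- coef(fit)[colnames(X)]

# Center each feature at its mean, then multiply by its coefficient
S <- sweep(sweep(X, 2, colMeans(X)), 2, beta, "*")

# Baseline = average prediction on the background data
baseline <- mean(fitted(fit))

# Additivity: baseline + sum of a row's SHAP values = that row's prediction
reconstructed <- baseline + rowSums(S)
head(cbind(prediction = fitted(fit), reconstructed))
```

For non-linear models there is no such closed form, which is where the SHAP algorithms of the packages below come in.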
In particular, the following plots are available:
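- sv_importance(): Importance plot (bar/beeswarm).
- sv_dependence() and sv_dependence2D(): Dependence plots to study feature effects and interactions.
- sv_interaction(): Interaction plot (beeswarm).
- sv_waterfall(): Waterfall plot to study single or average predictions.
- sv_force(): Force plot as alternative to waterfall plot.

SHAP and feature values are stored in a “shapviz” object that is built from:

- models that know how to calculate SHAP values themselves, e.g., XGBoost or boosted trees H2O models,
- SHAP crunchers like {treeshap}, {DALEX}, or {kernelshap}, or
- a matrix of SHAP values and the corresponding feature values.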
We use {patchwork} to glue together multiple plots with (potentially) inconsistent x and/or color scales.
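The following sketch calls the plot functions on a toy “shapviz” object built with the matrix interface. The SHAP values here are made up for illustration (they do not come from a fitted model), and the feature names x1, x2, x3 are hypothetical. A single feature in sv_dependence() gives a ggplot, while several features are combined into one {patchwork} object.

```r
library(shapviz)

set.seed(1)
n <- 100
X <- data.frame(
  x1 = rnorm(n),
  x2 = runif(n),
  x3 = sample(c("a", "b"), n, replace = TRUE)
)

# Hypothetical SHAP values, invented for illustration only
S <- cbind(
  x1 = 0.8 * X$x1,
  x2 = -(X$x2 - 0.5),
  x3 = ifelse(X$x3 == "a", 0.3, -0.3)
)
shp <- shapviz(S, X = X, baseline = 2)

p_imp <- sv_importance(shp, kind = "beeswarm")             # ggplot
p_dep <- sv_dependence(shp, v = "x1", color_var = "x2")    # ggplot
p_all <- sv_dependence(shp, v = c("x1", "x2", "x3"))       # patchwork
p_2d <- sv_dependence2D(shp, x = "x1", y = "x2")           # ggplot
p_all
```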
Shiny diamonds… let’s use XGBoost to model their prices by the four “C” variables.
library(shapviz)
library(ggplot2)
library(xgboost)
set.seed(1)
xvars <- c("log_carat", "cut", "color", "clarity")
X <- diamonds |>
  transform(log_carat = log(carat)) |>
  subset(select = xvars)
head(X)
#> log_carat cut color clarity
#> 1 -1.469676 Ideal E SI2
#> 2 -1.560648 Premium E SI1
#> 3 -1.469676 Good E VS1
#> 4 -1.237874 Premium I VS2
#> 5 -1.171183 Good J SI2
#> 6 -1.427116 Very Good J VVS2
# Fit (untuned) model
fit <- xgb.train(
  params = list(learning_rate = 0.1, nthread = 1),
  data = xgb.DMatrix(data.matrix(X), label = log(diamonds$price), nthread = 1),
  nrounds = 65
)
# SHAP analysis: X can even contain factors
X_explain <- X[sample(nrow(X), 2000), ]
shp <- shapviz(fit, X_pred = data.matrix(X_explain), X = X_explain)
sv_importance(shp, show_numbers = TRUE)
We can visualize decompositions of single predictions via waterfall or force plots, e.g., sv_waterfall(shp, row_id = 1) and sv_force(shp, row_id = 1).
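A small self-contained sketch of the single-prediction plots, again on a hypothetical toy object with invented SHAP values. Both plots start at the baseline and add the SHAP values of the selected row, so the value they end at is the baseline plus that row's SHAP sum.

```r
library(shapviz)

# Hypothetical toy data: 3 observations, 2 features
S <- matrix(
  c(0.5, -0.2, 1.1, 0.3, -0.4, 0.1),
  ncol = 2, dimnames = list(NULL, c("x", "y"))
)
X <- data.frame(x = c(1, 2, 3), y = c("u", "v", "w"))
shp <- shapviz(S, X = X, baseline = 10)

p_wf <- sv_waterfall(shp, row_id = 2)
p_fc <- sv_force(shp, row_id = 2)

# Value the plots end at: baseline + SHAP values of row 2
pred2 <- get_baseline(shp) + sum(get_shap_values(shp)[2, ])
pred2  # 9.4
```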
Multiple row_id values can also be passed: the SHAP values of the selected rows are averaged and then plotted as aggregated SHAP values. This way, we can, e.g., study the prediction profile of beautiful color “D” diamonds.
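The averaging can be checked by hand. The sketch below uses a toy object with invented SHAP values (the features color and size are hypothetical, not the diamonds data) and compares the selected rows' column means with what the aggregated waterfall plot is described to show.

```r
library(shapviz)

set.seed(2)
X <- data.frame(
  color = sample(c("D", "E", "F"), 50, replace = TRUE),
  size = runif(50)
)
# Invented SHAP values: every "D" row gets +0.4 for color
S <- cbind(color = ifelse(X$color == "D", 0.4, -0.1), size = X$size - 0.5)
shp <- shapviz(S, X = X, baseline = 1)

# Select all "D" rows; for the diamonds example above, the analogous call
# would presumably be sv_waterfall(shp, row_id = which(X_explain$color == "D"))
rows_D <- which(X$color == "D")
p <- sv_waterfall(shp, row_id = rows_D)

# Averaged SHAP values of the selected rows
avg_shap <- colMeans(get_shap_values(shp)[rows_D, , drop = FALSE])
avg_shap
```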
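If SHAP interaction values have been computed (via {xgboost} or {treeshap}), we can study them with sv_dependence() and sv_interaction().

Note that, in these plots, SHAP interaction values are multiplied by two (except main effects), since each pairwise interaction is split equally between its two symmetric off-diagonal entries.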
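A minimal sketch of the interaction workflow with {xgboost}, assuming a small untuned model on the numeric iris features (the data and hyperparameters are chosen only for illustration). Setting interactions = TRUE asks {shapviz} to also fetch the interaction values, stored as an n x p x p array.

```r
library(shapviz)
library(xgboost)

X <- data.matrix(iris[2:4])  # numeric features only
dtrain <- xgb.DMatrix(X, label = iris$Sepal.Length, nthread = 1)
fit <- xgb.train(
  params = list(learning_rate = 0.1, max_depth = 3, nthread = 1),
  data = dtrain,
  nrounds = 50
)

# Compute SHAP values and SHAP interaction values
shp <- shapviz(fit, X_pred = X, interactions = TRUE)

p_int <- sv_interaction(shp)
p_dep <- sv_dependence(
  shp, v = "Petal.Length", color_var = "Petal.Width", interactions = TRUE
)
dim(get_shap_interactions(shp))
```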
The above examples used XGBoost to calculate SHAP values. What about other packages?
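If you work with a boosted trees H2O model, shapviz() can extract its SHAP values directly, just as for XGBoost. With {treeshap}, e.g., for a {ranger} random forest: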
library(shapviz)
library(treeshap)
library(ranger)
fit <- ranger(
  y = iris$Sepal.Length, x = iris[-1], max.depth = 6, num.trees = 100
)
unified_model <- ranger.unify(fit, iris[-1])
shaps <- treeshap(unified_model, iris[-1], interactions = TRUE)
shp <- shapviz(shaps, X = iris)
sv_importance(shp)
sv_dependence(
  shp, "Sepal.Width", color_var = names(iris[-1]), alpha = 0.7, interactions = TRUE
)
Decompositions of single predictions obtained by the breakdown algorithm in DALEX:
library(shapviz)
library(DALEX)
library(ranger)
fit <- ranger(Sepal.Length ~ ., data = iris, max.depth = 6, num.trees = 100)
explainer <- DALEX::explain(fit, data = iris[-1], y = iris[, 1], label = "RF")
bd <- explainer |>
  predict_parts(iris[1, ], keep_distributions = FALSE) |>
  shapviz()
sv_waterfall(bd)
sv_force(bd)
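For any other model, SHAP values can be computed with {kernelshap}, either via kernelshap() or permshap(), and the result passed to shapviz().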
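A minimal sketch with permshap(), assuming a linear model on iris simply to keep the computation fast (any model with a prediction function would do). With only four features the permutation SHAP values are exact, so baseline plus row sums should reproduce the predictions.

```r
library(shapviz)
library(kernelshap)

fit <- lm(Sepal.Length ~ ., data = iris)
xvars <- colnames(iris[-1])

# Exact permutation SHAP; the background data is iris itself
ps <- permshap(fit, iris[xvars], bg_X = iris)
shp <- shapviz(ps)

p_imp <- sv_importance(shp)
p_dep <- sv_dependence(shp, xvars)

# Efficiency check: baseline + row sums = predictions
recon <- get_baseline(shp) + rowSums(get_shap_values(shp))
```

Swapping permshap() for kernelshap() with the same arguments should give a compatible object.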
The most general interface is to provide a matrix of SHAP values and corresponding feature values (and optionally, a baseline value):
S <- matrix(c(1, -1, -1, 1), ncol = 2, dimnames = list(NULL, c("x", "y")))
X <- data.frame(x = c("a", "b"), y = c(100, 10))
shp <- shapviz(S, X, baseline = 4)
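The stored pieces of such an object can be pulled out again with the getter functions. This sketch reuses the toy matrix above and shows that the implied predictions are the baseline plus the row sums of S; without a baseline argument, the baseline is 0.

```r
library(shapviz)

S <- matrix(c(1, -1, -1, 1), ncol = 2, dimnames = list(NULL, c("x", "y")))
X <- data.frame(x = c("a", "b"), y = c(100, 10))
shp <- shapviz(S, X, baseline = 4)

get_shap_values(shp)     # the SHAP matrix S
get_feature_values(shp)  # the feature data X
get_baseline(shp)        # 4

# Predictions implied by additivity: baseline + row sums of S
preds <- get_baseline(shp) + rowSums(get_shap_values(shp))

# Without 'baseline', it defaults to 0
get_baseline(shapviz(S, X))
```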
An example where this general interface helps is CatBoost: it is not on CRAN and requires catboost.*() functions to calculate SHAP values, so we cannot directly add it to {shapviz} for now. Use a wrapper like this:
library(shapviz)
library(catboost)
shapviz.catboost.Model <- function(object, X_pred, X = X_pred, collapse = NULL, ...) {
  # Evaluate the default X = X_pred before X_pred is turned into a pool
  force(X)
  if (!inherits(X_pred, "catboost.Pool")) {
    X_pred <- catboost.load_pool(X_pred)
  }
  S <- catboost.get_feature_importance(object, X_pred, type = "ShapValues", ...)
  # The last column holds the expected value (baseline)
  pp <- ncol(X_pred) + 1
  baseline <- S[1, pp]
  S <- S[, -pp, drop = FALSE]
  colnames(S) <- colnames(X_pred)
  shapviz(S, X = X, baseline = baseline, collapse = collapse)
}
# Example
X_pool <- catboost.load_pool(iris[-1], label = iris[, 1])
params <- list(loss_function = "RMSE", iterations = 65, allow_writing_files = FALSE)
fit <- catboost.train(X_pool, params = params)
shp <- shapviz(fit, X_pred = X_pool, X = iris)
sv_importance(shp)
sv_dependence(shp, colnames(iris[-1]))