The goal of this vignette is to show how the stochastic Shapley values produced by shapFlex correlate with the non-stochastic Shapley values computed in the Python shap package using the implementation discussed here. Treating the tree-based Shapley values from shap as an approximation of the ground truth, we would like to see if the sampling-based method in shapFlex can reproduce these values within sampling variability. We'll use catboost's R package, which has a port of shap in the catboost.get_feature_importance() function. While shap should be the preferred method when modeling with boosted trees, this vignette demonstrates that the more generic shapFlex implementation can be applied to all classes of ML prediction models, boosted tree models included.
# catboost indexes categorical features from 0, hence the -1.
cat_features <- which(unlist(Map(is.factor, data[, -outcome_col]))) - 1

# Convert the factor outcome to a 0/1 label and build the catboost data format.
data_pool <- catboost.load_pool(data = data[, -outcome_col],
                                label = as.vector(as.numeric(data[, outcome_col])) - 1,
                                cat_features = cat_features)

set.seed(224)
model_catboost <- catboost.train(data_pool, NULL,
                                 params = list(loss_function = 'CrossEntropy',
                                               iterations = 30, logging_level = "Silent"))
For shapFlex, the required user-defined prediction function takes 2 positional arguments, a model and a data.frame of features, and returns a 1-column data.frame of predictions. Note the creation of the catboost-specific data format inside this function.
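The definition of predict_function is not reproduced in this excerpt, so the following is a minimal sketch of what it could look like, assuming a binary classifier and catboost's "Probability" prediction type; the function name is chosen to match the predict_function argument passed to shapFlex::shapFlex() below.

predict_function <- function(model, data) {
  # Re-create the catboost-specific data format for the incoming feature matrix;
  # catboost indexes categorical features from 0.
  cat_features <- which(unlist(Map(is.factor, data))) - 1
  data_pool <- catboost.load_pool(data = data, cat_features = cat_features)
  # Return a 1-column data.frame of predictions, as shapFlex requires.
  data_pred <- data.frame("y_pred" = catboost.predict(model, data_pool,
                                                      prediction_type = "Probability"))
  return(data_pred)
}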
explain <- data[1:300, -outcome_col]  # Compute Shapley feature-level predictions for 300 instances.
reference <- data[, -outcome_col]  # An optional reference population to compute the baseline prediction.
sample_size <- 100  # Number of Monte Carlo samples.

set.seed(224)
data_shap <- shapFlex::shapFlex(explain = explain,
                                reference = reference,
                                model = model_catboost,
                                predict_function = predict_function,
                                sample_size = sample_size)
# Compute the tree-based Shapley values with catboost's port of shap for the same 300 instances.
data_pool <- catboost.load_pool(data = data[1:300, -outcome_col],
                                label = as.vector(as.numeric(data[1:300, outcome_col])) - 1,
                                cat_features = cat_features)

data_shap_catboost <- catboost.get_feature_importance(model_catboost, pool = data_pool,
                                                      type = "ShapValues")

data_shap_catboost <- data.frame(data_shap_catboost[, -ncol(data_shap_catboost)])  # Remove the intercept column.
data_shap_catboost$index <- 1:nrow(data_shap_catboost)

# Reshape to long format: one row per instance-feature pair.
data_shap_catboost <- tidyr::gather(data_shap_catboost, key = "feature_name",
                                    value = "shap_effect_catboost", -index)
# Correlation between the stochastic and tree-based Shapley values for each feature.
data_cor <- data_all %>%
  dplyr::group_by(feature_name) %>%
  dplyr::summarise("cor_coef" = round(cor(shap_effect, shap_effect_catboost), 3))

data_cor
## # A tibble: 13 x 2
## feature_name cor_coef
## <chr> <dbl>
## 1 age 0.997
## 2 capital_gain 0.999
## 3 capital_loss 0.986
## 4 education 0.992
## 5 education_num 0.998
## 6 hours_per_week 0.995
## 7 marital_status 0.993
## 8 native_country 0.997
## 9 occupation 0.997
## 10 race 0.997
## 11 relationship 0.991
## 12 sex 0.967
## 13 workclass 0.982
p <- ggplot(data_all, aes(shap_effect_catboost, shap_effect))
p <- p + geom_point(alpha = .25)
p <- p + geom_abline(color = "red")  # The 45-degree line of perfect agreement between methods.
p <- p + facet_wrap(~ feature_name, scales = "free")
p <- p + theme_bw()
p <- p + xlab("catboost tree-based Shapley values")
p <- p + ylab("shapFlex stochastic Shapley values")
p <- p + theme(axis.title = element_text(face = "bold"))
p