[RFC] Add interpretability API as xgboost.interpret module functions #11947
Summary
We have new work underway on SHAP (Shapley) values and related interpretability concepts. This will add new functionality; however, the current feature-importance and SHAP features are exposed through the predict API. I propose extending the Python API with a module (xgboost.interpret) for interpretability, containing stateless functions that expose the upcoming features.
These functions accept either a Booster or an sklearn-style XGB* model, plus DMatrix/array-like inputs, and return well-typed results (arrays and/or lightweight result objects).
Motivation
- Minimize disruption to the existing Booster/sklearn APIs while adding interpretability features.
- Improve discoverability and documentation (module-level functions are easy to document and test).
- Allow incremental implementation: start as wrappers over the existing predict(pred_contribs=..., pred_interactions=...), then evolve internals (especially top-k) without changing the public API.
Proposed public API
Add a new module:
xgboost/interpret.py
Functions (accept Booster | XGBModel and DMatrix | array-like | pandas):
- `shap_values(model, X, *, X_background=None, output_margin=False, iteration_range=None, approx=False, validate_features=True, feature_names=None, return_bias=False)`
- `shap_interactions(model, X, *, X_background=None, output_margin=False, iteration_range=None, approx=False, validate_features=True, feature_names=None)`
- `topk_interactions(model, X, *, X_background=None, k=50, metric="mean_abs", output_margin=False, iteration_range=None, validate_features=True, feature_names=None)` (note: possibly just fold this into `shap_interactions`)
- `partial_dependence(model, X, *, features, grid_resolution=50, percentiles=(0.05, 0.95), grid=None, sample_weights=None, random_state=0, output="prediction", iteration_range=None)`
- Possibly also add `shap_values`/`shap_interactions` methods to the Booster/sklearn classes for convenience.
Dispatch/behavior notes
- Internally normalize `model` to a `Booster`: either `model` already is a `Booster`, or it has `get_booster()`.
- Normalize `X` to a `DMatrix` if needed; respect feature names where possible.
- Initial SHAP implementations can wrap the existing `Booster.predict(..., pred_contribs=True / pred_interactions=True)` for compatibility.
- `topk_interactions` should ideally avoid materializing the full (n, p, p) tensor; target a C++ implementation that computes aggregated top-k pairs efficiently.
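To make the aggregation concrete, here is a pure-NumPy sketch of the `metric="mean_abs"` reduction that `topk_interactions` is intended to perform. This is reference semantics only; the point above is precisely that a production version would stream per-row contributions in C++ rather than build the full (n, p, p) tensor as this sketch does:

```python
import numpy as np


def topk_pairs_mean_abs(interactions, k):
    """Aggregate a (n, p, p) SHAP interaction tensor into the top-k
    feature pairs by mean absolute interaction strength.

    Reference-only sketch: a real implementation would avoid
    materializing the full tensor in the first place.
    """
    n, p, _ = interactions.shape
    # Mean |interaction| over rows; symmetrize and skip the diagonal
    # (main effects) so each unordered pair is counted once.
    strength = np.abs(interactions).mean(axis=0)
    strength = (strength + strength.T) / 2.0
    iu = np.triu_indices(p, k=1)  # off-diagonal upper triangle
    scores = strength[iu]
    order = np.argsort(scores)[::-1][:k]
    pairs = list(zip(iu[0][order], iu[1][order]))
    return pairs, scores[order]


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 4, 4))
x[:, 1, 2] += 5.0  # plant a strong (1, 2) interaction
x[:, 2, 1] += 5.0
pairs, scores = topk_pairs_mean_abs(x, k=2)
assert pairs[0] == (1, 2)
```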
Return types
Prefer lightweight result objects to keep outputs consistent and extensible:
- `ShapValues(values, base_values, feature_names, model_output, ...)`
- `ShapInteractions(values, feature_names, ...)` with helpers for main effects / pair extraction
- `TopKInteractions(pairs, scores, pair_names=None, per_row=None, ...)`
- `PDP(features, grid_values, averages, ...)`
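A lightweight result type along these lines could be a frozen dataclass. The field names follow the proposal above; the `total()` helper and the `model_output` default are illustrative assumptions:

```python
from dataclasses import dataclass
from typing import Optional, Sequence

import numpy as np


@dataclass(frozen=True)
class ShapValues:
    """Sketch of a lightweight SHAP result object (field names from the
    proposal; everything else is illustrative)."""

    values: np.ndarray                  # (n, p) per-feature contributions
    base_values: np.ndarray             # (n,) expected value / bias term
    feature_names: Optional[Sequence[str]] = None
    model_output: str = "margin"

    def total(self) -> np.ndarray:
        # Contributions plus bias reconstruct the (margin) prediction.
        return self.values.sum(axis=1) + self.base_values


sv = ShapValues(
    values=np.array([[0.5, -0.2], [0.1, 0.3]]),
    base_values=np.array([1.0, 1.0]),
    feature_names=["f0", "f1"],
)
```

Keeping the objects frozen and array-backed means they stay cheap to construct from existing predict outputs while leaving room to add fields later without breaking callers.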
Documentation plan (Sphinx)
- Add `docs/python/interpretability.rst`: narrative examples + API reference, using `.. autofunction::` for each function and `.. autoclass::` for the result types.
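A minimal stub for that page might look like the following (the `xgboost.interpret` paths assume the module layout proposed above):

```rst
Interpretability
================

.. autofunction:: xgboost.interpret.shap_values

.. autofunction:: xgboost.interpret.shap_interactions

.. autoclass:: xgboost.interpret.ShapValues
   :members:
```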