shap ¶
SHAP value computation and visualization for TabPFN.
This module provides functions to calculate and visualize SHAP (SHapley Additive exPlanations) values for TabPFN models. SHAP values help understand model predictions by attributing the contribution of each input feature to the output prediction.
Key features:

- Efficient parallel computation of SHAP values
- Support for both TabPFN and TabPFN-client backends
- Specialized explainers for TabPFN models
- Visualization functions for feature importance and interactions
- Backend-specific optimizations for faster SHAP computation
Example usage:

```python
from tabpfn import TabPFNClassifier
from tabpfn_extensions.interpretability import get_shap_values, plot_shap

# Train a TabPFN model
model = TabPFNClassifier()
model.fit(X_train, y_train)

# Calculate SHAP values
shap_values = get_shap_values(model, X_test)

# Visualize feature importance
plot_shap(shap_values)
```
calculate_shap_subset ¶
Calculate SHAP values for a specific feature in a parallel context.
This helper function is used by parallel_permutation_shap to enable efficient parallel computation of SHAP values for each feature.
Parameters:

Name | Type | Description | Default
---|---|---|---
args | tuple | A tuple of (X_subset, background, model, feature_idx): the feature matrix to compute SHAP values for, the background data for the explainer, the model to explain, and the index of the feature to compute SHAP values for. | required

Returns:

Type | Description
---|---
np.ndarray | SHAP values for the specified feature.

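
The single-tuple signature exists so each feature's computation can be shipped to a separate worker. The sketch below shows how such a dispatch could look with joblib; the import path, the argument packing, and passing the full matrix as X_subset are assumptions inferred from the parameter table above, not the library's internal implementation.

```python
# Illustrative only: drive calculate_shap_subset from a joblib worker pool.
# Import path and argument packing are assumptions based on the docs above.
import numpy as np
from joblib import Parallel, delayed

from tabpfn_extensions.interpretability.shap import calculate_shap_subset  # assumed path


def shap_values_per_feature(model, X, background, n_jobs=-1):
    # One task per feature: each worker receives (X_subset, background, model, feature_idx).
    tasks = [(X, background, model, feature_idx) for feature_idx in range(X.shape[1])]
    columns = Parallel(n_jobs=n_jobs)(
        delayed(calculate_shap_subset)(args) for args in tasks
    )
    # Stack the per-feature results into an (n_samples, n_features) matrix.
    return np.column_stack(columns)
```
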
get_default_explainer ¶
```python
get_default_explainer(
    estimator: Any,
    test_x: DataFrame,
    predict_function_for_shap: str | Callable = "predict",
    **kwargs: Any
) -> Any
```
Create a standard SHAP explainer for non-TabPFN models.
Parameters:

Name | Type | Description | Default
---|---|---|---
estimator | Any | The model to explain. | required
test_x | DataFrame | The input features to compute SHAP values for. | required
predict_function_for_shap | str \| Callable | Function name or callable to use for prediction. | "predict"
**kwargs | Any | Additional keyword arguments passed to the SHAP explainer. | {}

Returns:

Type | Description
---|---
Any | A configured SHAP explainer for the model.

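
A hedged usage sketch with a scikit-learn model rather than TabPFN. The import path and the assumption that the returned explainer follows the standard callable shap interface are inferred, not confirmed by the documentation above.

```python
# Sketch, not a verbatim API example; synthetic data is illustrative only.
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

from tabpfn_extensions.interpretability.shap import get_default_explainer  # assumed path

X, y = make_classification(n_samples=200, n_features=5, random_state=0)
X_df = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])])

model = RandomForestClassifier(random_state=0).fit(X_df, y)

explainer = get_default_explainer(
    model,
    X_df,
    predict_function_for_shap="predict_proba",
)
shap_values = explainer(X_df)  # assumed: standard callable shap explainer interface
```
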
get_shap_values ¶
```python
get_shap_values(
    estimator: Any,
    test_x: DataFrame | ndarray | Tensor,
    attribute_names: list[str] | None = None,
    **kwargs: Any
) -> ndarray
```
Compute SHAP values for a model's predictions on input features.
This function calculates SHAP (SHapley Additive exPlanations) values that attribute the contribution of each input feature to the model's output. It automatically selects the appropriate SHAP explainer based on the model.
Parameters:

Name | Type | Description | Default
---|---|---|---
estimator | Any | The model to explain, typically a TabPFNClassifier or scikit-learn compatible model. | required
test_x | DataFrame \| ndarray \| Tensor | The input features to compute SHAP values for. | required
attribute_names | list[str] \| None | Column names for the features when test_x is a numpy array. | None
**kwargs | Any | Additional keyword arguments passed to the SHAP explainer. | {}

Returns:

Type | Description
---|---
np.ndarray | The computed SHAP values with shape (n_samples, n_features).

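
A self-contained sketch using a numpy array plus attribute_names, following the pattern of the module-level example. The synthetic data and feature names are illustrative only.

```python
# Minimal end-to-end sketch with numpy input; attribute_names supplies
# readable column labels. Synthetic data is illustrative only.
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier
from tabpfn_extensions.interpretability import get_shap_values

X, y = make_classification(n_samples=200, n_features=5, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = TabPFNClassifier()
model.fit(X_train, y_train)

shap_values = get_shap_values(
    model,
    X_test,
    attribute_names=[f"feature_{i}" for i in range(X.shape[1])],
)
print(shap_values.shape)  # documented as (n_samples, n_features)
```
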
get_tabpfn_explainer ¶
```python
get_tabpfn_explainer(
    estimator: Any,
    test_x: DataFrame,
    predict_function_for_shap: str | Callable = "predict",
    **kwargs: Any
) -> Any
```
Create a SHAP explainer specifically optimized for TabPFN models.
Parameters:

Name | Type | Description | Default
---|---|---|---
estimator | Any | The TabPFN model to explain. | required
test_x | DataFrame | The input features to compute SHAP values for. | required
predict_function_for_shap | str \| Callable | Function name or callable to use for prediction. | "predict"
**kwargs | Any | Additional keyword arguments passed to the SHAP explainer. | {}

Returns:

Type | Description
---|---
Any | A configured SHAP explainer for the TabPFN model.

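
A sketch of building the TabPFN-specific explainer directly instead of going through get_shap_values. The import path, the choice of predict_proba as the SHAP target, and the callable-explainer usage are assumptions.

```python
# Sketch under assumptions: import path, predict_proba target, and the
# callable-explainer interface are not confirmed by the docs above.
import pandas as pd

from tabpfn import TabPFNClassifier
from tabpfn_extensions.interpretability.shap import get_tabpfn_explainer  # assumed path

# X_train, y_train, X_test as in the get_shap_values example above.
X_test_df = pd.DataFrame(X_test, columns=[f"feature_{i}" for i in range(X_test.shape[1])])

model = TabPFNClassifier()
model.fit(X_train, y_train)

explainer = get_tabpfn_explainer(
    model,
    X_test_df,
    predict_function_for_shap="predict_proba",
)
shap_values = explainer(X_test_df)  # assumed: standard shap call style
```
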
parallel_permutation_shap ¶
```python
parallel_permutation_shap(
    model: Any,
    X: ndarray | DataFrame,
    background: ndarray | DataFrame | None = None,
    n_jobs: int = -1,
) -> ndarray
```
Calculate SHAP values efficiently using parallel processing.
This function distributes the SHAP value calculation across multiple processes, with each process computing values for a different feature. This is much faster than calculating all SHAP values at once, especially for large datasets or complex models.
Parameters:

Name | Type | Description | Default
---|---|---|---
model | Any | The model for which to calculate SHAP values. Must have a prediction method. | required
X | ndarray \| DataFrame | Feature matrix for which to calculate SHAP values. | required
background | ndarray \| DataFrame \| None | Background data for the explainer. If None, X is used as the background data. | None
n_jobs | int | Number of processes to use for parallel computation. If -1, all available CPU cores are used. | -1

Returns:

Type | Description
---|---
np.ndarray | Matrix of SHAP values with shape (n_samples, n_features).

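
A sketch with an explicit, subsampled background set, continuing from the get_shap_values example above; the 50-row background size is an arbitrary illustrative choice, and the import path is assumed.

```python
# Sketch: parallel SHAP with an explicitly subsampled background set.
# model, X_train, X_test come from the get_shap_values example above.
import numpy as np

from tabpfn_extensions.interpretability.shap import parallel_permutation_shap  # assumed path

rng = np.random.default_rng(0)
idx = rng.choice(len(X_train), size=min(50, len(X_train)), replace=False)

shap_matrix = parallel_permutation_shap(
    model,
    X_test,
    background=X_train[idx],
    n_jobs=-1,  # use all available CPU cores
)
print(shap_matrix.shape)  # (n_samples, n_features)
```
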
plot_shap ¶
Plot SHAP values for the given test data.
This function creates several visualizations of SHAP values:

1. Aggregated feature importances across all examples
2. Per-sample feature importances
3. Important feature interactions (if multiple samples are provided)
Parameters:

Name | Type | Description | Default
---|---|---|---
shap_values | ndarray | The SHAP values to plot, typically from get_shap_values(). | required

Returns:

Type | Description
---|---
None | This function only produces visualizations.

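
Continuing the module-level example; the explicit plt.show() is included defensively, since whether plot_shap renders its figures itself depends on the plotting backend.

```python
# Continuing the module-level example; model and X_test are defined there.
import matplotlib.pyplot as plt

from tabpfn_extensions.interpretability import get_shap_values, plot_shap

shap_values = get_shap_values(model, X_test)
plot_shap(shap_values)  # aggregated importances, per-sample plots, interactions
plt.show()  # may be unnecessary if plot_shap already displays its figures
```
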
plot_shap_feature ¶
Plot feature interactions for a specific feature based on SHAP values.
Parameters:

Name | Type | Description | Default
---|---|---|---
shap_values_ | Any | SHAP values object containing the data to plot. | required
feature_name | int \| str | The feature index or name to plot interactions for. | required
n_plots | int | Number of interaction plots to create. | 1

Returns:

Type | Description
---|---
None | This function only produces visualizations.

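
A short sketch; whether feature_name accepts string names in addition to indices likely depends on how the SHAP values were built (for example, whether attribute_names or a DataFrame supplied column labels), so both forms below are assumptions, as is the import path.

```python
# Sketch: interaction plots for one feature, using the SHAP values computed
# in the examples above. The string form assumes labelled features.
from tabpfn_extensions.interpretability.shap import plot_shap_feature  # assumed path

plot_shap_feature(shap_values, 0, n_plots=2)              # by feature index
# plot_shap_feature(shap_values, "feature_0", n_plots=2)  # by name, if labels are available
```
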