
shap

SHAP value computation and visualization for TabPFN.

This module provides functions to calculate and visualize SHAP (SHapley Additive exPlanations) values for TabPFN models. SHAP values help understand model predictions by attributing the contribution of each input feature to the output prediction.

Key features:

- Efficient parallel computation of SHAP values
- Support for both TabPFN and TabPFN-client backends
- Specialized explainers for TabPFN models
- Visualization functions for feature importance and interactions
- Backend-specific optimizations for faster SHAP computation

Example usage
from tabpfn import TabPFNClassifier
from tabpfn_extensions.interpretability import get_shap_values, plot_shap

# Train a TabPFN model
model = TabPFNClassifier()
model.fit(X_train, y_train)

# Calculate SHAP values
shap_values = get_shap_values(model, X_test)

# Visualize feature importance
plot_shap(shap_values)

calculate_shap_subset

calculate_shap_subset(args: tuple) -> ndarray

Calculate SHAP values for a specific feature in a parallel context.

This helper function is used by parallel_permutation_shap to enable efficient parallel computation of SHAP values for each feature.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `args` | `tuple` | A tuple of (`X_subset`, `background`, `model`, `feature_idx`): the feature matrix for which to calculate SHAP values, background data for the explainer, the model to explain, and the index of the target feature. | required |

Returns:

| Type | Description |
|------|-------------|
| `ndarray` | SHAP values for the specified feature. |
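The single-tuple argument exists so the function can be passed to a process pool's `map`. The unpacking pattern can be sketched with a toy mask-and-diff attribution (hypothetical names; the real helper runs a SHAP explainer, not this crude difference):

```python
import numpy as np

def toy_shap_subset(args):
    # Mirrors the (X_subset, background, model, feature_idx) tuple layout.
    # Attribution here is a simple mask-and-diff, not true SHAP.
    X_subset, background, model, feature_idx = args
    X_masked = X_subset.copy()
    X_masked[:, feature_idx] = background[:, feature_idx].mean()
    return model(X_subset) - model(X_masked)

model = lambda X: X.sum(axis=1)            # toy additive model
X = np.array([[1.0, 2.0], [3.0, 4.0]])
bg = np.zeros((2, 2))                      # all-zeros background
vals = toy_shap_subset((X, bg, model, 0))  # contribution of feature 0
```

For an additive model with a zero background, the masked difference recovers each sample's value of the target feature.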

get_default_explainer

get_default_explainer(
    estimator: Any,
    test_x: DataFrame,
    predict_function_for_shap: str | Callable = "predict",
    **kwargs: Any
) -> Any

Create a standard SHAP explainer for non-TabPFN models.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `estimator` | `Any` | The model to explain. | required |
| `test_x` | `DataFrame` | The input features to compute SHAP values for. | required |
| `predict_function_for_shap` | `str \| Callable` | Function name or callable to use for prediction. | `'predict'` |
| `**kwargs` | `Any` | Additional keyword arguments to pass to the SHAP explainer. | `{}` |

Returns:

| Type | Description |
|------|-------------|
| `Any` | A configured SHAP explainer for the model. |
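Since `predict_function_for_shap` accepts either a method name or a callable, one plausible resolution step looks like the following sketch (`resolve_predict` and `Dummy` are hypothetical, not part of the module):

```python
def resolve_predict(estimator, predict_function_for_shap="predict"):
    # Hypothetical helper: accept either a bound-method name or a callable.
    if callable(predict_function_for_shap):
        return predict_function_for_shap
    return getattr(estimator, predict_function_for_shap)

class Dummy:
    def predict(self, X):
        return [0] * len(X)

    def predict_proba(self, X):
        return [[0.5, 0.5]] * len(X)

f = resolve_predict(Dummy(), "predict_proba")
probs = f([[1], [2]])
```

Passing `"predict_proba"` is the usual choice for classifiers, so SHAP attributions are computed on class probabilities rather than hard labels.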

get_shap_values

get_shap_values(
    estimator: Any,
    test_x: DataFrame | ndarray | Tensor,
    attribute_names: list[str] | None = None,
    **kwargs: Any
) -> ndarray

Compute SHAP values for a model's predictions on input features.

This function calculates SHAP (SHapley Additive exPlanations) values that attribute the contribution of each input feature to the model's output. It automatically selects the appropriate SHAP explainer based on the model.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `estimator` | `Any` | The model to explain, typically a TabPFNClassifier or scikit-learn compatible model. | required |
| `test_x` | `DataFrame \| ndarray \| Tensor` | The input features to compute SHAP values for. | required |
| `attribute_names` | `list[str] \| None` | Column names for the features when `test_x` is a numpy array. | `None` |
| `**kwargs` | `Any` | Additional keyword arguments to pass to the SHAP explainer. | `{}` |

Returns:

| Type | Description |
|------|-------------|
| `ndarray` | The computed SHAP values with shape `(n_samples, n_features)`. |
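Because `test_x` may arrive as a bare array with separate `attribute_names`, a preprocessing step along these lines is a reasonable mental model (a sketch with hypothetical names, not the module's actual code; the real function also handles Tensors):

```python
import numpy as np
import pandas as pd

def to_named_frame(test_x, attribute_names=None):
    # Hypothetical normalization: give array inputs column names so that
    # downstream SHAP plots can label features meaningfully.
    if isinstance(test_x, pd.DataFrame):
        return test_x
    X = np.asarray(test_x)
    cols = attribute_names or [f"feature_{i}" for i in range(X.shape[1])]
    return pd.DataFrame(X, columns=cols)

df = to_named_frame(np.zeros((3, 2)), attribute_names=["age", "income"])
```

Passing `attribute_names` is worthwhile whenever `test_x` is an array: without names, plots fall back to generic feature indices.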

get_tabpfn_explainer

get_tabpfn_explainer(
    estimator: Any,
    test_x: DataFrame,
    predict_function_for_shap: str | Callable = "predict",
    **kwargs: Any
) -> Any

Create a SHAP explainer specifically optimized for TabPFN models.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `estimator` | `Any` | The TabPFN model to explain. | required |
| `test_x` | `DataFrame` | The input features to compute SHAP values for. | required |
| `predict_function_for_shap` | `str \| Callable` | Function name or callable to use for prediction. | `'predict'` |
| `**kwargs` | `Any` | Additional keyword arguments to pass to the SHAP explainer. | `{}` |

Returns:

| Type | Description |
|------|-------------|
| `Any` | A configured SHAP explainer for the TabPFN model. |

parallel_permutation_shap

parallel_permutation_shap(
    model: Any,
    X: ndarray | DataFrame,
    background: ndarray | DataFrame | None = None,
    n_jobs: int = -1,
) -> ndarray

Calculate SHAP values efficiently using parallel processing.

This function distributes the SHAP value calculation across multiple processes, with each process computing values for a different feature. This is much faster than calculating all SHAP values at once, especially for large datasets or complex models.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `Any` | The model for which to calculate SHAP values. Must have a prediction method. | required |
| `X` | `ndarray \| DataFrame` | Feature matrix for which to calculate SHAP values. | required |
| `background` | `ndarray \| DataFrame \| None` | Background data for the explainer. If `None`, `X` is used as background data. | `None` |
| `n_jobs` | `int` | Number of processes to use for parallel computation. If `-1`, all available CPU cores are used. | `-1` |

Returns:

| Type | Description |
|------|-------------|
| `ndarray` | Matrix of SHAP values with shape `(n_samples, n_features)`. |
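The per-feature fan-out can be sketched as follows. This is a toy stand-in under stated assumptions: attribution is a simple mask-and-diff rather than true SHAP, and a thread pool replaces the worker processes the real function uses; all names here are hypothetical.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def masked_contribution(args):
    # Toy per-feature attribution: mask one column with the background mean.
    X, background, predict, j = args
    X_masked = X.copy()
    X_masked[:, j] = background[:, j].mean()
    return predict(X) - predict(X_masked)

def toy_parallel_shap(predict, X, background=None, n_jobs=4):
    # Sketch of the fan-out: one task per feature, results stacked column-wise.
    if background is None:
        background = X  # documented fallback: X doubles as the background
    tasks = [(X, background, predict, j) for j in range(X.shape[1])]
    with ThreadPoolExecutor(max_workers=n_jobs) as ex:
        cols = list(ex.map(masked_contribution, tasks))
    return np.column_stack(cols)  # (n_samples, n_features)

predict = lambda X: X @ np.array([2.0, -1.0])  # linear toy model
X = np.array([[1.0, 1.0], [2.0, 0.0]])
vals = toy_parallel_shap(predict, X, background=np.zeros((1, 2)))
```

Splitting by feature keeps each task independent, which is what makes process-level parallelism effective here: no shared state beyond the read-only inputs.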

plot_shap

plot_shap(shap_values: ndarray) -> None

Plot SHAP values for the given test data.

This function creates several visualizations of SHAP values:

1. Aggregated feature importances across all examples
2. Per-sample feature importances
3. Important feature interactions (if multiple samples are provided)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `shap_values` | `ndarray` | The SHAP values to plot, typically from `get_shap_values()`. | required |

Returns:

| Type | Description |
|------|-------------|
| `None` | This function only produces visualizations. |

plot_shap_feature

plot_shap_feature(
    shap_values_: Any,
    feature_name: int | str,
    n_plots: int = 1,
) -> None

Plot feature interactions for a specific feature based on SHAP values.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `shap_values_` | `Any` | SHAP values object containing the data to plot. | required |
| `feature_name` | `int \| str` | The feature index or name to plot interactions for. | required |
| `n_plots` | `int` | Number of interaction plots to create. | `1` |

Returns:

| Type | Description |
|------|-------------|
| `None` | This function only produces visualizations. |