sklearn_based_decision_tree_tabpfn

DecisionTreeTabPFNBase

Bases: BaseDecisionTree, BaseEstimator

Abstract base class combining a scikit-learn Decision Tree with TabPFN at the leaves.

This class provides a hybrid approach by combining the standard decision tree splitting algorithm from scikit-learn with TabPFN models at the leaves or internal nodes. This allows for both interpretable tree-based partitioning and high-performance TabPFN prediction.

Key features:

• Inherits from sklearn's BaseDecisionTree to leverage standard tree splitting algorithms
• Uses TabPFN (Classifier or Regressor) to fit leaf nodes (or all internal nodes)
• Provides optional adaptive pruning logic that dynamically determines the optimal tree depth
• Supports both classification and regression through specialized subclasses

Subclasses:

• DecisionTreeTabPFNClassifier - for classification tasks
• DecisionTreeTabPFNRegressor - for regression tasks

Parameters

tabpfn : Any
    A TabPFN instance (TabPFNClassifier or TabPFNRegressor) that will be used at tree nodes.
criterion : str
    The function to measure the quality of a split (from sklearn).
splitter : str
    The strategy used to choose the split at each node (e.g. "best" or "random").
max_depth : int, optional
    The maximum depth of the tree (None means unlimited).
min_samples_split : int
    The minimum number of samples required to split an internal node.
min_samples_leaf : int
    The minimum number of samples required to be at a leaf node.
min_weight_fraction_leaf : float
    The minimum weighted fraction of the sum total of weights required to be at a leaf node.
max_features : Union[int, float, str, None]
    The number of features to consider when looking for the best split.
random_state : Union[int, np.random.RandomState, None]
    Controls the randomness of the estimator.
max_leaf_nodes : Optional[int]
    If not None, grow a tree with max_leaf_nodes in best-first fashion.
min_impurity_decrease : float
    A node will be split if the split decreases the impurity by at least this value.
class_weight : Optional[Union[Dict[int, float], str]]
    Only used in classification. Dict of class -> weight, or "balanced".
ccp_alpha : float
    Complexity parameter used for Minimal Cost-Complexity Pruning (non-negative).
monotonic_cst : Any
    Optional monotonicity constraints (depending on sklearn version).
categorical_features : Optional[List[int]]
    Indices of categorical features for TabPFN usage (if any).
verbose : Union[bool, int]
    Verbosity level; higher values produce more output.
show_progress : bool
    Whether to show progress bars for leaf/node fitting using TabPFN.
fit_nodes : bool
    Whether to fit TabPFN at internal nodes (True) or only at final leaves (False).
tree_seed : int
    Used to set seeds for TabPFN fitting in each node.
adaptive_tree : bool
    Whether to do adaptive node-by-node pruning using a hold-out strategy.
adaptive_tree_min_train_samples : int
    Minimum number of training samples required to fit a TabPFN in a node.
adaptive_tree_max_train_samples : int
    Number of training samples above which a node that is not a final leaf might be pruned.
adaptive_tree_min_valid_samples_fraction_of_train : float
    Minimum number of validation points, expressed as a fraction of the node's training points, required to consider a node for re-fitting.
adaptive_tree_overwrite_metric : Optional[str]
    If set, overrides the default metric for pruning, e.g. "roc" or "rmse".
adaptive_tree_test_size : float
    Fraction of data to hold out for adaptive pruning if no separate validation set is provided.
average_logits : bool
    Whether to average logits (True) or probabilities (False) when combining predictions.
adaptive_tree_skip_class_missing : bool
    If True, skip re-fitting if the node's training set does not contain all classes (classification only).
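To make the effect of `average_logits` concrete, here is a minimal pure-Python sketch of the two combination modes. It is an illustration under stated assumptions, not the library's implementation: the helper names `softmax` and `combine` are hypothetical, "logits" are taken to mean log-probabilities, and all probabilities are assumed to be strictly positive.

```python
import math


def softmax(logits):
    """Turn a list of logits into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def combine(prob_rows, average_logits):
    """Combine several probability vectors for one sample (hypothetical helper)."""
    n_rows = len(prob_rows)
    n_classes = len(prob_rows[0])
    if average_logits:
        # Assumption: logits are log-probabilities; average them, then renormalize.
        mean_logits = [
            sum(math.log(row[c]) for row in prob_rows) / n_rows
            for c in range(n_classes)
        ]
        return softmax(mean_logits)
    # Otherwise take the plain arithmetic mean of the probabilities.
    return [sum(row[c] for row in prob_rows) / n_rows for c in range(n_classes)]


node_a = [0.9, 0.1]  # prediction from one node
node_b = [0.6, 0.4]  # prediction from another node
by_logit = combine([node_a, node_b], average_logits=True)
by_prob = combine([node_a, node_b], average_logits=False)
```

In this toy case the logit average is more confident in the first class than the probability average, because averaging log-probabilities corresponds to a normalized geometric mean.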

tree_ property

tree_

Expose the fitted tree for sklearn compatibility.

Returns:

sklearn.tree._tree.Tree
    Underlying scikit-learn tree object.

fit

fit(
    X: NDArray[float64],
    y: NDArray[Any],
    sample_weight: NDArray[float64] | None = None,
    check_input: bool = True,
) -> DecisionTreeTabPFNBase

Fit the DecisionTree + TabPFN model.

This method trains the hybrid model by:

1. Building a decision tree structure
2. Fitting TabPFN models at the leaves (or at all nodes if fit_nodes=True)
3. Optionally performing adaptive pruning if adaptive_tree=True
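The hold-out idea behind the optional adaptive pruning step can be sketched in plain Python. This is a rough illustration only: the helpers `holdout_split` and `keep_split` are hypothetical, and the rule "keep a split only if the children beat the parent on held-out rows" is an assumption about how such pruning could work, not a description of the library's exact logic. RMSE stands in for the pruning metric.

```python
import random


def holdout_split(n_samples, test_size, seed):
    """Return (train_idx, valid_idx), holding out about test_size of the rows."""
    idx = list(range(n_samples))
    random.Random(seed).shuffle(idx)  # seeded, so the split is reproducible
    n_valid = max(1, round(n_samples * test_size))
    return sorted(idx[n_valid:]), sorted(idx[:n_valid])


def rmse(y_true, y_pred):
    """Root mean squared error between two equal-length sequences."""
    squared = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return (sum(squared) / len(squared)) ** 0.5


def keep_split(y_valid, parent_pred, child_pred):
    """Assumed rule: keep the split only if children beat the parent on held-out rows."""
    return rmse(y_valid, child_pred) < rmse(y_valid, parent_pred)


train_idx, valid_idx = holdout_split(n_samples=10, test_size=0.2, seed=0)
split_helps = keep_split([1.0, 2.0], parent_pred=[1.5, 1.5], child_pred=[1.1, 1.9])
split_hurts = keep_split([1.0, 2.0], parent_pred=[1.0, 2.0], child_pred=[1.5, 1.5])
```

Here `adaptive_tree_test_size` would play the role of `test_size`, and `tree_seed` or `random_state` the role of `seed`; that mapping is likewise an assumption.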

Parameters:

X : NDArray[float64], required
    The training input samples, shape (n_samples, n_features).
y : NDArray[Any], required
    The target values (class labels for classification, real values for regression), shape (n_samples,) or (n_samples, n_outputs).
sample_weight : NDArray[float64] | None, default=None
    Sample weights. If None, all samples are weighted equally.
check_input : bool, default=True
    Whether to validate the input data arrays.

Returns:

self : DecisionTreeTabPFNBase
    Fitted estimator.

fit_leaves

fit_leaves(train_X: ndarray, train_y: ndarray) -> None

Fit a TabPFN model in each leaf node (or each node, if self.fit_nodes=True).

This populates an internal dictionary of training data for each leaf so that TabPFN can make predictions at these leaves.

Parameters

train_X : np.ndarray
    Training features for all samples.
train_y : np.ndarray
    Training labels/targets for all samples.
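The per-node dictionary described above can be pictured with a small pure-Python sketch. Everything here is hypothetical: the toy `TREE` layout, `decision_path`, and `collect_node_data` are stand-ins invented for illustration, and the real method works on the fitted scikit-learn tree rather than on nested dicts.

```python
from collections import defaultdict

# Hypothetical toy tree: internal nodes carry a split, leaves are empty dicts.
TREE = {
    0: {"feature": 0, "threshold": 0.5, "left": 1, "right": 2},
    1: {},  # leaf
    2: {"feature": 1, "threshold": 2.0, "left": 3, "right": 4},
    3: {},  # leaf
    4: {},  # leaf
}


def decision_path(x, tree=TREE):
    """Return the node ids visited by sample x, from the root to its leaf."""
    node, path = 0, [0]
    while "feature" in tree[node]:
        spec = tree[node]
        go_left = x[spec["feature"]] <= spec["threshold"]
        node = spec["left"] if go_left else spec["right"]
        path.append(node)
    return path


def collect_node_data(train_X, train_y, fit_nodes=False):
    """Group training rows by leaf, or by every visited node if fit_nodes=True."""
    data = defaultdict(lambda: ([], []))
    for x, y in zip(train_X, train_y):
        path = decision_path(x)
        targets = path if fit_nodes else path[-1:]
        for node in targets:
            data[node][0].append(x)
            data[node][1].append(y)
    return dict(data)


train_X = [[0.1, 1.0], [0.9, 1.0], [0.8, 3.0], [0.7, 5.0]]
train_y = [0, 1, 1, 0]
leaves_only = collect_node_data(train_X, train_y)
all_nodes = collect_node_data(train_X, train_y, fit_nodes=True)
```

With `fit_nodes=False` only leaf ids appear as keys; with `fit_nodes=True` the root and every internal node on a sample's path also receive that sample.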

get_tree

get_tree() -> BaseDecisionTree

Return the underlying fitted sklearn decision tree.

Returns:

BaseDecisionTree
    The fitted decision tree (DecisionTreeClassifier or DecisionTreeRegressor).

Raises:

NotFittedError
    If the model has not been fitted yet.

DecisionTreeTabPFNClassifier

Bases: DecisionTreeTabPFNBase, ClassifierMixin

Decision tree that uses TabPFNClassifier at the leaves.

tree_ property

tree_

Expose the fitted tree for sklearn compatibility.

Returns:

sklearn.tree._tree.Tree
    Underlying scikit-learn tree object.

fit

fit(
    X: NDArray[float64],
    y: NDArray[Any],
    sample_weight: NDArray[float64] | None = None,
    check_input: bool = True,
) -> DecisionTreeTabPFNBase

Fit the DecisionTree + TabPFN model.

This method trains the hybrid model by:

1. Building a decision tree structure
2. Fitting TabPFN models at the leaves (or at all nodes if fit_nodes=True)
3. Optionally performing adaptive pruning if adaptive_tree=True

Parameters:

X : NDArray[float64], required
    The training input samples, shape (n_samples, n_features).
y : NDArray[Any], required
    The target values (class labels for classification, real values for regression), shape (n_samples,) or (n_samples, n_outputs).
sample_weight : NDArray[float64] | None, default=None
    Sample weights. If None, all samples are weighted equally.
check_input : bool, default=True
    Whether to validate the input data arrays.

Returns:

self : DecisionTreeTabPFNBase
    Fitted estimator.

fit_leaves

fit_leaves(train_X: ndarray, train_y: ndarray) -> None

Fit a TabPFN model in each leaf node (or each node, if self.fit_nodes=True).

This populates an internal dictionary of training data for each leaf so that TabPFN can make predictions at these leaves.

Parameters

train_X : np.ndarray
    Training features for all samples.
train_y : np.ndarray
    Training labels/targets for all samples.

get_tree

get_tree() -> BaseDecisionTree

Return the underlying fitted sklearn decision tree.

Returns:

BaseDecisionTree
    The fitted decision tree (DecisionTreeClassifier or DecisionTreeRegressor).

Raises:

NotFittedError
    If the model has not been fitted yet.

predict

predict(X: ndarray, check_input: bool = True) -> ndarray

Predict class labels for X.

Parameters:

X : ndarray, required
    Input features.
check_input : bool, default=True
    Whether to validate input arrays.

Returns:

ndarray
    Predicted class labels.

predict_proba

predict_proba(
    X: ndarray, check_input: bool = True
) -> ndarray

Predict class probabilities for X using the TabPFN leaves.

Parameters:

X : ndarray, required
    Input features.
check_input : bool, default=True
    Whether to validate input arrays.

Returns:

ndarray
    Predicted probabilities of shape (n_samples, n_classes).
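A common scikit-learn convention, assumed here rather than confirmed for this class, is that predict returns the class whose predict_proba column is largest. The sketch below shows that mapping in plain Python; `labels_from_proba` is a hypothetical helper, and `classes` plays the role of a fitted estimator's class list.

```python
def labels_from_proba(proba, classes):
    """Map each probability row to the label of its highest-probability class."""
    labels = []
    for row in proba:
        # Index of the largest probability; ties resolve to the first index.
        best = max(range(len(row)), key=row.__getitem__)
        labels.append(classes[best])
    return labels


proba = [[0.2, 0.8], [0.7, 0.3], [0.5, 0.5]]  # shape (n_samples, n_classes)
classes = ["no", "yes"]
predicted = labels_from_proba(proba, classes)
```

The output has one label per input row, matching the (n_samples,) shape of predicted class labels.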

DecisionTreeTabPFNRegressor

Bases: DecisionTreeTabPFNBase, RegressorMixin

Decision tree that uses TabPFNRegressor at the leaves.

tree_ property

tree_

Expose the fitted tree for sklearn compatibility.

Returns:

sklearn.tree._tree.Tree
    Underlying scikit-learn tree object.

fit

fit(
    X: NDArray[float64],
    y: NDArray[Any],
    sample_weight: NDArray[float64] | None = None,
    check_input: bool = True,
) -> DecisionTreeTabPFNBase

Fit the DecisionTree + TabPFN model.

This method trains the hybrid model by:

1. Building a decision tree structure
2. Fitting TabPFN models at the leaves (or at all nodes if fit_nodes=True)
3. Optionally performing adaptive pruning if adaptive_tree=True

Parameters:

X : NDArray[float64], required
    The training input samples, shape (n_samples, n_features).
y : NDArray[Any], required
    The target values (class labels for classification, real values for regression), shape (n_samples,) or (n_samples, n_outputs).
sample_weight : NDArray[float64] | None, default=None
    Sample weights. If None, all samples are weighted equally.
check_input : bool, default=True
    Whether to validate the input data arrays.

Returns:

self : DecisionTreeTabPFNBase
    Fitted estimator.

fit_leaves

fit_leaves(train_X: ndarray, train_y: ndarray) -> None

Fit a TabPFN model in each leaf node (or each node, if self.fit_nodes=True).

This populates an internal dictionary of training data for each leaf so that TabPFN can make predictions at these leaves.

Parameters

train_X : np.ndarray
    Training features for all samples.
train_y : np.ndarray
    Training labels/targets for all samples.

get_tree

get_tree() -> BaseDecisionTree

Return the underlying fitted sklearn decision tree.

Returns:

BaseDecisionTree
    The fitted decision tree (DecisionTreeClassifier or DecisionTreeRegressor).

Raises:

NotFittedError
    If the model has not been fitted yet.

predict

predict(X: ndarray, check_input: bool = True) -> ndarray

Predict regression values using the TabPFN leaves.

Parameters

X : np.ndarray
    Input features.
check_input : bool, default=True
    Whether to validate the input arrays.

Returns:

np.ndarray
    Continuous predictions of shape (n_samples,).

predict_full

predict_full(X: ndarray) -> ndarray

Convenience method that predicts without validating the input.

Parameters

X : np.ndarray
    Input features.

Returns:

np.ndarray
    Continuous predictions of shape (n_samples,).