sklearn_based_decision_tree_tabpfn
DecisionTreeTabPFNBase
Bases: BaseDecisionTree, BaseEstimator
Abstract base class combining a scikit-learn Decision Tree with TabPFN at the leaves.
This class provides a hybrid approach by combining the standard decision tree splitting algorithm from scikit-learn with TabPFN models at the leaves or internal nodes. This allows for both interpretable tree-based partitioning and high-performance TabPFN prediction.
Key features:
• Inherits from scikit-learn's BaseDecisionTree to reuse its standard tree-splitting algorithms
• Uses TabPFN (classifier or regressor) to fit the leaf nodes (or the internal nodes as well, when fit_nodes=True)
• Provides optional adaptive pruning logic that determines the tree depth dynamically
• Supports both classification and regression through specialized subclasses
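The hybrid idea can be illustrated with plain scikit-learn: fit a shallow tree, route samples to leaves with `apply`, and fit a separate model on each leaf's data. This is a simplified sketch, not this class's implementation: `LogisticRegression` stands in for TabPFN, and `leaf_models` and `predict_hybrid` are hypothetical names.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=400, n_features=5, random_state=0)

# Step 1: standard scikit-learn splitting defines the partition.
tree = DecisionTreeClassifier(max_depth=2, min_samples_leaf=40, random_state=0).fit(X, y)
leaf_ids = tree.apply(X)  # index of the leaf each training sample falls into

# Step 2: fit one model per leaf (LogisticRegression stands in for TabPFN).
leaf_models = {}
for leaf in np.unique(leaf_ids):
    mask = leaf_ids == leaf
    if len(np.unique(y[mask])) < 2:
        leaf_models[leaf] = int(y[mask][0])  # pure leaf: just remember its class
    else:
        leaf_models[leaf] = LogisticRegression().fit(X[mask], y[mask])


def predict_hybrid(X_new):
    """Route each sample to its leaf and let that leaf's model predict."""
    preds = np.empty(len(X_new), dtype=int)
    new_leaves = tree.apply(X_new)
    for leaf in np.unique(new_leaves):
        mask = new_leaves == leaf
        model = leaf_models[leaf]
        preds[mask] = model if isinstance(model, int) else model.predict(X_new[mask])
    return preds


print("training accuracy:", (predict_hybrid(X) == y).mean())
```

Keeping `min_samples_leaf` large gives every leaf model enough training data to fit on.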
Subclasses:
• DecisionTreeTabPFNClassifier: for classification tasks
• DecisionTreeTabPFNRegressor: for regression tasks
Parameters
tabpfn : Any
    A TabPFN instance (TabPFNClassifier or TabPFNRegressor) used at the tree nodes.
criterion : str
    The function used to measure the quality of a split (as in scikit-learn).
splitter : str
    The strategy used to choose the split at each node (e.g. "best" or "random").
max_depth : int, optional
    The maximum depth of the tree (None means unlimited).
min_samples_split : int
    The minimum number of samples required to split an internal node.
min_samples_leaf : int
    The minimum number of samples required to be at a leaf node.
min_weight_fraction_leaf : float
    The minimum weighted fraction of the sum total of weights required to be at a leaf node.
max_features : Union[int, float, str, None]
    The number of features to consider when looking for the best split.
random_state : Union[int, np.random.RandomState, None]
    Controls the randomness of the estimator.
max_leaf_nodes : Optional[int]
    If not None, grow a tree with max_leaf_nodes in best-first fashion.
min_impurity_decrease : float
    A node is split if the split induces a decrease in impurity greater than or equal to this value.
class_weight : Optional[Union[Dict[int, float], str]]
    Classification only. A dict mapping class -> weight, or "balanced".
ccp_alpha : float
    Non-negative complexity parameter used for Minimal Cost-Complexity Pruning.
monotonic_cst : Any
    Optional monotonicity constraints (availability depends on the scikit-learn version).
categorical_features : Optional[List[int]]
    Indices of categorical features for use by TabPFN (if any).
verbose : Union[bool, int]
    Verbosity level; higher values produce more output.
show_progress : bool
    Whether to show progress bars while fitting TabPFN at the leaves/nodes.
fit_nodes : bool
    Whether to fit TabPFN at the internal nodes as well (True) or only at the final leaves (False).
tree_seed : int
    Used to set the seeds for TabPFN fitting in each node.
adaptive_tree : bool
    Whether to perform adaptive node-by-node pruning using a hold-out strategy.
adaptive_tree_min_train_samples : int
    Minimum number of training samples required to fit a TabPFN model in a node.
adaptive_tree_max_train_samples : int
    Maximum number of training samples above which a node may be pruned if it is not a final leaf.
adaptive_tree_min_valid_samples_fraction_of_train : float
    Fraction controlling the minimum number of validation/test samples required to consider a node for re-fitting.
adaptive_tree_overwrite_metric : Optional[str]
    If set, overrides the default pruning metric, e.g. "roc" or "rmse".
adaptive_tree_test_size : float
    Fraction of the data to hold out for adaptive pruning if no separate validation set is provided.
average_logits : bool
    Whether to average logits (True) or probabilities (False) when combining predictions.
adaptive_tree_skip_class_missing : bool
    If True, skip re-fitting when the node's training set does not contain all classes (classification only).
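The difference between the two settings of `average_logits` can be shown on two probability vectors predicted for the same sample. This sketch assumes "logits" means log-probabilities and that averaged logits are turned back into probabilities with a softmax; how and where the class combines predictions internally is not specified here, and `combine` is a hypothetical helper.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def combine(prob_vectors, average_logits):
    """Combine several class-probability vectors predicted for one sample."""
    n = len(prob_vectors)
    k = len(prob_vectors[0])
    if average_logits:
        # Assumption: "logits" = log-probabilities; average them, then renormalize.
        mean_logits = [sum(math.log(p[j]) for p in prob_vectors) / n for j in range(k)]
        return softmax(mean_logits)
    # Plain average of the probabilities.
    return [sum(p[j] for p in prob_vectors) / n for j in range(k)]


preds = [[0.9, 0.1], [0.5, 0.5]]
print(combine(preds, average_logits=False))  # approximately [0.7, 0.3]
print(combine(preds, average_logits=True))   # approximately [0.75, 0.25]
```

Averaging log-probabilities behaves like a normalized geometric mean, so it leans toward the more confident prediction (0.75 rather than 0.7 for class 0 here).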
tree_ (property)
Expose the fitted tree for sklearn compatibility.
Returns:
sklearn.tree._tree.Tree: Underlying scikit-learn tree object.
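Because `tree_` is the standard scikit-learn `Tree` object, its arrays (`node_count`, `children_left`, `feature`, `threshold`, `n_node_samples`) can be used to inspect the partition. The snippet uses a plain `DecisionTreeClassifier` on the Iris data as a stand-in; on a fitted hybrid model the same attributes should be reachable through `tree_`, as the documented return type suggests.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
t = clf.tree_  # a sklearn.tree._tree.Tree

# A node is a leaf when it has no children (children_left == -1).
leaves = [node for node in range(t.node_count) if t.children_left[node] == -1]

for node in range(t.node_count):
    if node in leaves:
        print(f"node {node}: leaf, {t.n_node_samples[node]} training samples")
    else:
        print(f"node {node}: split on X[:, {t.feature[node]}] <= {t.threshold[node]:.2f}")
```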
fit
fit(
X: NDArray[float64],
y: NDArray[Any],
sample_weight: NDArray[float64] | None = None,
check_input: bool = True,
) -> DecisionTreeTabPFNBase
Fit the DecisionTree + TabPFN model.
This method trains the hybrid model by:
1. Building the decision tree structure.
2. Fitting TabPFN models at the leaves (or at all nodes if fit_nodes=True).
3. Optionally performing adaptive pruning if adaptive_tree=True.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | NDArray[float64] | The training input samples, shape (n_samples, n_features). | required |
y | NDArray[Any] | The target values (class labels for classification, real values for regression), shape (n_samples,) or (n_samples, n_outputs). | required |
sample_weight | Optional[NDArray[float64]] | Sample weights. If None, samples are weighted equally. | None |
check_input | bool | Whether to validate the input data arrays. | True |
Returns:
Name | Type | Description |
---|---|---|
self | DecisionTreeTabPFNBase | Fitted estimator. |
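Step 3 of `fit` depends on deciding whether a node has usable training data. The sketch below mirrors two documented rules, `adaptive_tree_min_train_samples` and `adaptive_tree_skip_class_missing`, as a hypothetical helper `node_is_eligible`; it does not reproduce the hold-out comparison itself, and the class's actual check may differ.

```python
def node_is_eligible(node_y, all_classes, min_train_samples, skip_class_missing=True):
    """Hypothetical check: should a TabPFN model be (re-)fitted at this node?"""
    # Mirrors adaptive_tree_min_train_samples: too few samples, no fit.
    if len(node_y) < min_train_samples:
        return False
    # Mirrors adaptive_tree_skip_class_missing (classification only):
    # every class must appear in the node's training labels.
    if skip_class_missing and not set(all_classes) <= set(node_y):
        return False
    return True


print(node_is_eligible([0, 1, 2, 1, 0], [0, 1, 2], min_train_samples=5))  # True
print(node_is_eligible([0, 1, 1, 1, 0], [0, 1, 2], min_train_samples=5))  # False: class 2 missing
print(node_is_eligible([0, 1, 2], [0, 1, 2], min_train_samples=5))        # False: too few samples
```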
fit_leaves
Fit a TabPFN model in each leaf node (or each node, if self.fit_nodes=True).
This populates an internal dictionary of training data for each leaf so that TabPFN can make predictions at these leaves.
Parameters
train_X : np.ndarray
    Training features for all samples.
train_y : np.ndarray
    Training labels/targets for all samples.
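The bookkeeping in `fit_leaves` can be sketched with scikit-learn's routing methods: `apply` gives each sample's leaf, and `decision_path` gives every node a sample passes through, which corresponds to the `fit_nodes=True` case. The helper `collect_node_data` and its dict layout (node id mapped to `(X, y)`) are hypothetical; the class's internal dictionary may be organized differently.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def collect_node_data(tree_model, train_X, train_y, fit_nodes=False):
    """Map node id -> (X, y) of the training samples that reach that node."""
    node_data = {}
    if fit_nodes:
        # decision_path returns a sparse (n_samples, n_nodes) indicator matrix.
        indicator = tree_model.decision_path(train_X).toarray().astype(bool)
        for node in range(indicator.shape[1]):
            mask = indicator[:, node]
            node_data[node] = (train_X[mask], train_y[mask])
    else:
        # Only leaves: apply returns the leaf index of every sample.
        leaf_ids = tree_model.apply(train_X)
        for leaf in np.unique(leaf_ids):
            mask = leaf_ids == leaf
            node_data[int(leaf)] = (train_X[mask], train_y[mask])
    return node_data


X, y = load_iris(return_X_y=True)
tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
leaves = collect_node_data(tree, X, y)
all_nodes = collect_node_data(tree, X, y, fit_nodes=True)
print(sorted(leaves), sorted(all_nodes))
```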
get_tree
Return the underlying fitted sklearn decision tree.
Returns:
Type | Description |
---|---|
BaseDecisionTree | DecisionTreeClassifier or DecisionTreeRegressor: the fitted decision tree. |
Raises:
Type | Description |
---|---|
NotFittedError | If the model has not been fitted yet. |
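A minimal sketch of the `get_tree` contract: raise `NotFittedError` before `fit`, otherwise return an ordinary scikit-learn tree. `TreeHolder` and its `_tree` attribute are hypothetical; only `NotFittedError`, `DecisionTreeClassifier`, and `export_text` are real scikit-learn names.

```python
from sklearn.exceptions import NotFittedError
from sklearn.tree import DecisionTreeClassifier, export_text


class TreeHolder:
    """Hypothetical minimal wrapper illustrating the get_tree contract."""

    def __init__(self):
        self._tree = None  # hypothetical attribute, populated by fit()

    def fit(self, X, y):
        self._tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
        return self

    def get_tree(self):
        if self._tree is None:
            raise NotFittedError("This model has not been fitted yet; call fit() first.")
        return self._tree


holder = TreeHolder()
try:
    holder.get_tree()
except NotFittedError as exc:
    print("before fit:", exc)

# After fitting, the returned object is an ordinary scikit-learn tree,
# so standard tools such as export_text work on it.
holder.fit([[0.0], [1.0], [2.0], [3.0]], [0, 0, 1, 1])
print(export_text(holder.get_tree()))
```

Since the returned object is a regular scikit-learn tree estimator, other utilities such as `plot_tree` should apply as well.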
DecisionTreeTabPFNClassifier
Bases: DecisionTreeTabPFNBase, ClassifierMixin
Decision tree that uses TabPFNClassifier at the leaves.
Inherits tree_, fit, fit_leaves, and get_tree from DecisionTreeTabPFNBase; their documentation is identical to the base class above.
predict
Predict class labels for X.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | ndarray | Input features. | required |
check_input | bool | Whether to validate the input arrays. | True |
Returns:
Type | Description |
---|---|
ndarray | Predicted class labels. |
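Assuming `predict` follows the usual scikit-learn convention of taking the arg-max of `predict_proba` and mapping column indices to class labels (not confirmed by this page), the conversion looks like this; `predict_from_proba` is a hypothetical helper.

```python
import numpy as np


def predict_from_proba(proba, classes):
    """Map each row of class probabilities to the label with the highest one."""
    return np.asarray(classes)[np.argmax(proba, axis=1)]


proba = np.array([[0.2, 0.8], [0.6, 0.4]])
print(predict_from_proba(proba, classes=["no", "yes"]))
```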
predict_proba
Predict class probabilities for X using the TabPFN leaves.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | ndarray | Input features. | required |
check_input | bool | Whether to validate the input arrays. | True |
Returns:
Type | Description |
---|---|
ndarray | Predicted class probabilities, shape (n_samples, n_classes). |
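`predict_proba` returns shape (n_samples, n_classes), but a leaf whose training data covered only some of the classes yields fewer probability columns. One simple way to reconcile this, assumed here rather than taken from the class, is to place each leaf's columns into the global class order and give unseen classes probability 0; `align_leaf_proba` is a hypothetical helper.

```python
import numpy as np


def align_leaf_proba(leaf_proba, leaf_classes, all_classes):
    """Place a leaf model's probability columns into the global class order.

    Classes never seen in the leaf's training data get probability 0.
    """
    all_classes = list(all_classes)
    full = np.zeros((leaf_proba.shape[0], len(all_classes)))
    for j, cls in enumerate(leaf_classes):
        full[:, all_classes.index(cls)] = leaf_proba[:, j]
    return full


# A leaf model trained only on classes 0 and 2 (class 1 never reached this leaf).
leaf_proba = np.array([[0.3, 0.7]])
print(align_leaf_proba(leaf_proba, leaf_classes=[0, 2], all_classes=[0, 1, 2]))
```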
DecisionTreeTabPFNRegressor
Bases: DecisionTreeTabPFNBase, RegressorMixin
Decision tree that uses TabPFNRegressor at the leaves.
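The regression variant follows the same pattern: a tree partitions the input, and a regressor fitted on each leaf's samples makes the prediction. In this sketch `LinearRegression` stands in for `TabPFNRegressor`, and the synthetic data and the `predict_hybrid` helper are made up for illustration.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor

# Piecewise-linear target with a jump at x = 0.
X = np.linspace(-1, 1, 200).reshape(-1, 1)
y = np.where(X[:, 0] < 0, X[:, 0], 3 * X[:, 0] + 10)

tree = DecisionTreeRegressor(max_depth=1).fit(X, y)
leaf_ids = tree.apply(X)
# LinearRegression stands in for TabPFNRegressor at each leaf.
leaf_models = {
    leaf: LinearRegression().fit(X[leaf_ids == leaf], y[leaf_ids == leaf])
    for leaf in np.unique(leaf_ids)
}


def predict_hybrid(X_new):
    """Route each sample to its leaf and use that leaf's regressor."""
    out = np.empty(len(X_new))
    new_leaves = tree.apply(X_new)
    for leaf in np.unique(new_leaves):
        mask = new_leaves == leaf
        out[mask] = leaf_models[leaf].predict(X_new[mask])
    return out


print(predict_hybrid(np.array([[-0.5], [0.5]])))  # approximately [-0.5, 11.5]
```

A plain depth-1 regression tree would predict one constant per leaf; the per-leaf models also recover the slope inside each region.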
Inherits tree_, fit, fit_leaves, and get_tree from DecisionTreeTabPFNBase; their documentation is identical to the base class above.