regressor

TabPFNRegressor class.

Example

import sklearn.datasets
from tabpfn import TabPFNRegressor

model = TabPFNRegressor()
X, y = sklearn.datasets.make_regression(n_samples=50, n_features=10)

model.fit(X, y)
predictions = model.predict(X)

TabPFNRegressor

Bases: RegressorMixin, BaseEstimator

TabPFNRegressor class.

bardist_ instance-attribute

The bar distribution of the target variable, used by the model.

config_ instance-attribute

config_: InferenceConfig

The configuration of the loaded model to be used for inference.

device_ instance-attribute

device_: device

The device selected for running inference.

executor_ instance-attribute

executor_: InferenceEngine

The inference engine used to make predictions.

feature_names_in_ instance-attribute

feature_names_in_: NDArray[Any]

The feature names of the input data.

May not be set if the input data does not have feature names, such as with a numpy array.
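Because this attribute can be absent, code that reads it should guard the access. The following is a minimal sketch: `feature_names_or_default` is a hypothetical helper, the `x0, x1, ...` fallback naming is an arbitrary choice, and `SimpleNamespace` objects stand in for fitted models so the snippet runs without loading TabPFN.

```python
from types import SimpleNamespace


def feature_names_or_default(model):
    """Return the fitted feature names, or generated placeholders.

    Hypothetical helper: relies only on the documented attributes
    `feature_names_in_` (optional) and `n_features_in_`.
    """
    names = getattr(model, "feature_names_in_", None)
    if names is not None:
        return [str(name) for name in names]
    # No names were recorded (e.g. the model was fitted on a NumPy array).
    return [f"x{i}" for i in range(model.n_features_in_)]


# Stand-ins for fitted models; a real TabPFNRegressor would be used instead.
fitted_on_dataframe = SimpleNamespace(
    feature_names_in_=["age", "income"], n_features_in_=2
)
fitted_on_array = SimpleNamespace(n_features_in_=3)

print(feature_names_or_default(fitted_on_dataframe))
print(feature_names_or_default(fitted_on_array))
```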

forced_inference_dtype_ instance-attribute

forced_inference_dtype_: _dtype | None

The forced inference dtype for the model based on inference_precision.

inferred_categorical_indices_ instance-attribute

inferred_categorical_indices_: list[int]

The indices of the columns inferred to be categorical. The result combines the features the user marked as categorical with what is expected to work best for the model.

interface_config_ instance-attribute

interface_config_: ModelInterfaceConfig

Additional configuration of the interface for expert users.

n_features_in_ instance-attribute

n_features_in_: int

The number of features in the input data used during fit().

n_outputs_ instance-attribute

n_outputs_: Literal[1]

The number of outputs the model supports. Currently always 1.

preprocessor_ instance-attribute

preprocessor_: ColumnTransformer

The column transformer used to preprocess the input data to be numeric.

renormalized_criterion_ instance-attribute

renormalized_criterion_: FullSupportBarDistribution

The normalized bar distribution used for computing the predictions.
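To make the term "bar distribution" concrete, here is a toy sketch in plain Python. It treats the prediction as a piecewise-uniform density over bins and derives a mean, mode, and quantiles from it. This is an illustration only, not TabPFN's implementation: the function names are hypothetical, the borders and probabilities are made up, and any special handling of the outermost bins in `FullSupportBarDistribution` is ignored.

```python
def bar_mean(borders, probs):
    """Mean of a piecewise-uniform density: sum of p_i * bin midpoint."""
    return sum(p * (lo + hi) / 2 for p, lo, hi in zip(probs, borders, borders[1:]))


def bar_mode(borders, probs):
    """Midpoint of the bin with the highest density (probability / width)."""
    densities = [p / (hi - lo) for p, lo, hi in zip(probs, borders, borders[1:])]
    i = densities.index(max(densities))
    return (borders[i] + borders[i + 1]) / 2


def bar_quantile(borders, probs, q):
    """Invert the piecewise-linear CDF at level q (0 <= q <= 1)."""
    cumulative = 0.0
    for p, lo, hi in zip(probs, borders, borders[1:]):
        if p > 0 and cumulative + p >= q:
            # Interpolate linearly inside the bin that contains q.
            return lo + (q - cumulative) / p * (hi - lo)
        cumulative += p
    return borders[-1]


# Made-up example: three bins, the last one twice as wide as the others.
borders = [0.0, 1.0, 2.0, 4.0]
probs = [0.25, 0.5, 0.25]

mean = bar_mean(borders, probs)
median = bar_quantile(borders, probs, 0.5)
mode = bar_mode(borders, probs)
print(mean, median, mode)
```

Note that the mode is chosen by density rather than raw bin probability, so a wide bin does not win merely because it covers more of the target range.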

use_autocast_ instance-attribute

use_autocast_: bool

Whether torch's autocast should be used.

y_train_mean_ instance-attribute

y_train_mean_: float

The mean of the target variable during training.

y_train_std instance-attribute

y_train_std: float

The standard deviation of the target variable during training.
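A common use for stored target statistics like these is to z-score the target before fitting and to map predictions back to the original scale afterwards. This page does not state exactly how the estimator applies them, so the snippet below is a sketch under that assumption. The helper names are hypothetical, and the choice of population standard deviation is arbitrary.

```python
import statistics

y_train = [10.0, 12.0, 14.0, 16.0, 18.0]

# Statistics analogous to the stored training mean and standard deviation.
y_mean = statistics.fmean(y_train)
y_std = statistics.pstdev(y_train)  # population std; an assumption of this sketch


def standardize(y):
    """Map targets to zero mean and unit variance."""
    return [(value - y_mean) / y_std for value in y]


def destandardize(z):
    """Map standardized values back to the original target scale."""
    return [value * y_std + y_mean for value in z]


z = standardize(y_train)
restored = destandardize(z)
print(z)
print(restored)
```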

fit

fit(X: XType, y: YType) -> Self

Fit the model.

Parameters:

Name Type Description Default
X XType

The input data.

required
y YType

The target variable.

required

Returns:

Type Description
Self

self

predict

predict(
    X: XType,
    *,
    output_type: Literal[
        "mean",
        "median",
        "mode",
        "quantiles",
        "full",
        "main",
    ] = "mean",
    quantiles: list[float] | None = None
) -> (
    ndarray
    | list[ndarray]
    | dict[str, ndarray]
    | dict[str, ndarray | FullSupportBarDistribution]
)

Predict the target variable.

Parameters:

Name Type Description Default
X XType

The input data.

required
output_type Literal['mean', 'median', 'mode', 'quantiles', 'full', 'main']

Determines the type of output to return.

  • If "mean", we return the mean over the predicted distribution.
  • If "median", we return the median over the predicted distribution.
  • If "mode", we return the mode over the predicted distribution.
  • If "quantiles", we return the quantiles of the predicted distribution. The quantiles parameter determines which quantiles are returned.
  • If "main", we return all of the output types above in a dict.
  • If "full", we return the full output of the model, including the logits and the criterion, together with all the output types from "main".
'mean'
quantiles list[float] | None

The quantiles to return if output_type="quantiles".

By default, the [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] quantiles are returned. The predictions per quantile match the input order.

None

Returns:

Type Description
ndarray | list[ndarray] | dict[str, ndarray] | dict[str, ndarray | FullSupportBarDistribution]

The predicted target variable as an array, a list of prediction arrays (one per quantile) for "quantiles", or a dict of outputs for "main" and "full".
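Since the per-quantile predictions come back in the same order as the requested quantiles, pairing them up is straightforward. The sketch below uses only the documented call `predict(X, output_type="quantiles", quantiles=...)`. `predict_quantiles` is a hypothetical helper, and `_StubRegressor` is a stand-in with a fake output so the snippet runs without loading TabPFN.

```python
def predict_quantiles(model, X, quantiles):
    """Map each requested quantile to its array of predictions.

    Hypothetical helper: assumes `model.predict` returns one prediction
    array per quantile, in the order the quantiles were requested.
    """
    per_quantile = model.predict(X, output_type="quantiles", quantiles=quantiles)
    return dict(zip(quantiles, per_quantile))


class _StubRegressor:
    """Stand-in with the same predict signature as documented above."""

    def predict(self, X, *, output_type="mean", quantiles=None):
        if output_type != "quantiles":
            raise ValueError("stub only implements output_type='quantiles'")
        # Fake output: every row gets the quantile level itself as its prediction.
        return [[q for _ in X] for q in quantiles]


X = [[0.0, 1.0], [2.0, 3.0]]
bands = predict_quantiles(_StubRegressor(), X, [0.1, 0.5, 0.9])
lower, upper = bands[0.1], bands[0.9]
print(lower, upper)
```

With a fitted TabPFNRegressor in place of the stub, `lower` and `upper` would bound a central 80% prediction interval for each row.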