base

Common logic for TabPFN models.

create_inference_engine

create_inference_engine(
    *,
    X_train: ndarray,
    y_train: ndarray,
    model: PerFeatureTransformer,
    ensemble_configs: Any,
    cat_ix: list[int],
    fit_mode: Literal[
        "low_memory", "fit_preprocessors", "fit_with_cache"
    ],
    device_: device,
    rng: Generator,
    n_jobs: int,
    byte_size: int,
    forced_inference_dtype_: dtype | None,
    memory_saving_mode: (
        bool | Literal["auto"] | float | int
    ),
    use_autocast_: bool
) -> InferenceEngine

Creates the appropriate TabPFN inference engine based on fit_mode.

Each execution mode performs slightly different operations depending on the fit_mode specified by the user. When preprocessors are fit during the prepare step, they are also used to further transform the borders associated with each ensemble config member.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X_train` | `ndarray` | Training features. | required |
| `y_train` | `ndarray` | Training target. | required |
| `model` | `PerFeatureTransformer` | The loaded TabPFN model. | required |
| `ensemble_configs` | `Any` | The ensemble configurations used to create multiple "prompts". | required |
| `cat_ix` | `list[int]` | Indices of inferred categorical features. | required |
| `fit_mode` | `Literal['low_memory', 'fit_preprocessors', 'fit_with_cache']` | Determines how we prepare inference (pre-cache or not). | required |
| `device_` | `device` | The device for inference. | required |
| `rng` | `Generator` | NumPy random generator. | required |
| `n_jobs` | `int` | Number of parallel CPU workers. | required |
| `byte_size` | `int` | Byte size for the chosen inference precision. | required |
| `forced_inference_dtype_` | `dtype \| None` | If not None, the forced dtype for inference. | required |
| `memory_saving_mode` | `bool \| Literal['auto'] \| float \| int` | GPU/CPU memory saving settings. | required |
| `use_autocast_` | `bool` | Whether to use torch.autocast for inference. | required |
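For orientation, here is a minimal sketch of building an engine by hand, assuming these helpers can be imported as `from tabpfn.base import ...` (an assumption about the module path). The toy data and the undefined `ensemble_configs` are illustrative placeholders; in practice the estimator classes derive these internally.

```python
import numpy as np
import torch

# Assumed import path for the helpers documented on this page.
from tabpfn.base import (
    create_inference_engine,
    determine_precision,
    initialize_tabpfn_model,
)

# Toy training data; a real estimator passes its validated fit() inputs.
X_train = np.random.rand(64, 4).astype(np.float32)
y_train = np.random.randint(0, 2, size=64)

# Load the model (see initialize_tabpfn_model below).
model, config, _ = initialize_tabpfn_model(
    model_path="auto",
    which="classifier",
    fit_mode="fit_preprocessors",
    static_seed=0,
)

# Resolve precision settings (see determine_precision below).
device_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_autocast_, forced_dtype_, byte_size = determine_precision("auto", device_)

engine = create_inference_engine(
    X_train=X_train,
    y_train=y_train,
    model=model,
    ensemble_configs=ensemble_configs,  # placeholder: normally derived by the estimator
    cat_ix=[],                          # no categorical features in this toy data
    fit_mode="fit_preprocessors",
    device_=device_,
    rng=np.random.default_rng(0),
    n_jobs=1,
    byte_size=byte_size,
    forced_inference_dtype_=forced_dtype_,
    memory_saving_mode="auto",
    use_autocast_=use_autocast_,
)
```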

determine_precision

determine_precision(
    inference_precision: (
        dtype | Literal["autocast", "auto"]
    ),
    device_: device,
) -> tuple[bool, dtype | None, int]

Decide whether to use autocast or a forced precision dtype.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `inference_precision` | `dtype \| Literal['autocast', 'auto']` | If `"auto"`, decide automatically based on the device. If `"autocast"`, explicitly use PyTorch autocast (mixed precision). If a `torch.dtype`, force that precision. | required |
| `device_` | `device` | The device on which inference is run. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `use_autocast_` | `bool` | True if mixed-precision autocast will be used. |
| `forced_inference_dtype_` | `dtype \| None` | If not None, the forced precision dtype for the model. |
| `byte_size` | `int` | The byte size per element for the chosen precision. |
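A short usage sketch; the comments restate the documented meaning of the return values rather than asserting exact outputs, which depend on the device:

```python
import torch

device_ = torch.device("cpu")

# "auto": let the helper decide based on the device.
use_autocast_, forced_dtype_, byte_size = determine_precision("auto", device_)

# Forcing a concrete dtype: per the descriptions above, forced_dtype_ should
# then be torch.float32 and byte_size its per-element size (4 bytes).
use_autocast_, forced_dtype_, byte_size = determine_precision(torch.float32, device_)
```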

initialize_tabpfn_model

initialize_tabpfn_model(
    model_path: str | Path | Literal["auto"],
    which: Literal["classifier", "regressor"],
    fit_mode: Literal[
        "low_memory", "fit_preprocessors", "fit_with_cache"
    ],
    static_seed: int,
) -> tuple[
    PerFeatureTransformer,
    InferenceConfig,
    FullSupportBarDistribution | None,
]

Common logic to load the TabPFN model, set up the random state, and optionally download the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_path` | `str \| Path \| Literal['auto']` | Path, or the directive `"auto"`, to load the pre-trained model from. | required |
| `which` | `Literal['classifier', 'regressor']` | Which TabPFN model to load. | required |
| `fit_mode` | `Literal['low_memory', 'fit_preprocessors', 'fit_with_cache']` | Determines caching behavior. | required |
| `static_seed` | `int` | Random seed for reproducibility logic. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `model` | `PerFeatureTransformer` | The loaded TabPFN model. |
| `config` | `InferenceConfig` | The configuration object associated with the loaded model. |
| `bar_distribution` | `FullSupportBarDistribution \| None` | The BarDistribution for regression (`None` for a classifier). |
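A hedged sketch of loading both model variants, illustrating the classifier/regressor difference in the third return value (per the description above, the `"auto"` path may download the checkpoint):

```python
# Classifier: no BarDistribution is returned.
clf_model, clf_config, clf_bar = initialize_tabpfn_model(
    model_path="auto",
    which="classifier",
    fit_mode="fit_preprocessors",
    static_seed=42,
)
assert clf_bar is None

# Regressor: the FullSupportBarDistribution accompanies the model.
reg_model, reg_config, reg_bar = initialize_tabpfn_model(
    model_path="auto",
    which="regressor",
    fit_mode="fit_preprocessors",
    static_seed=42,
)
assert reg_bar is not None
```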