base ¶
Common logic for TabPFN models.
create_inference_engine ¶
```python
create_inference_engine(
    *,
    X_train: ndarray,
    y_train: ndarray,
    model: PerFeatureTransformer,
    ensemble_configs: Any,
    cat_ix: list[int],
    fit_mode: Literal["low_memory", "fit_preprocessors", "fit_with_cache"],
    device_: device,
    rng: Generator,
    n_jobs: int,
    byte_size: int,
    forced_inference_dtype_: dtype | None,
    memory_saving_mode: bool | Literal["auto"] | float | int,
    use_autocast_: bool,
) -> InferenceEngine
```
Creates the appropriate TabPFN inference engine based on `fit_mode`.
Each execution mode performs slightly different operations depending on the mode
specified by the user. In the case where preprocessors are fit after `prepare`,
we use them to further transform the borders associated with each ensemble
config member.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X_train` | `ndarray` | Training features. | *required* |
| `y_train` | `ndarray` | Training target. | *required* |
| `model` | `PerFeatureTransformer` | The loaded TabPFN model. | *required* |
| `ensemble_configs` | `Any` | The ensemble configurations to create multiple "prompts". | *required* |
| `cat_ix` | `list[int]` | Indices of inferred categorical features. | *required* |
| `fit_mode` | `Literal['low_memory', 'fit_preprocessors', 'fit_with_cache']` | Determines how we prepare inference (pre-cache or not). | *required* |
| `device_` | `device` | The device for inference. | *required* |
| `rng` | `Generator` | NumPy random generator. | *required* |
| `n_jobs` | `int` | Number of parallel CPU workers. | *required* |
| `byte_size` | `int` | Byte size for the chosen inference precision. | *required* |
| `forced_inference_dtype_` | `dtype \| None` | If not None, the forced dtype for inference. | *required* |
| `memory_saving_mode` | `bool \| Literal['auto'] \| float \| int` | GPU/CPU memory saving settings. | *required* |
| `use_autocast_` | `bool` | Whether we use torch.autocast for inference. | *required* |
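For orientation, here is a minimal sketch of how the three helpers on this page fit together. It assumes the helpers are importable from `tabpfn.base`; the `ensemble_configs` value is left as an empty placeholder because its construction (normally handled by the estimator's preprocessing) is not covered on this page.

```python
import numpy as np
import torch

# Assumed import path for the helpers documented on this page.
from tabpfn.base import (
    create_inference_engine,
    determine_precision,
    initialize_tabpfn_model,
)

rng = np.random.default_rng(0)
X_train = rng.normal(size=(100, 4)).astype(np.float32)
y_train = rng.integers(0, 2, size=100)

# Load the classifier checkpoint ("auto" resolves/downloads the default model).
model, config, _ = initialize_tabpfn_model(
    model_path="auto",
    which="classifier",
    fit_mode="fit_preprocessors",
    static_seed=0,
)

device_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_autocast_, forced_dtype_, byte_size = determine_precision("auto", device_)

engine = create_inference_engine(
    X_train=X_train,
    y_train=y_train,
    model=model,
    ensemble_configs=[],  # placeholder: real configs come from the estimator's preprocessing
    cat_ix=[],            # no categorical features in this synthetic data
    fit_mode="fit_preprocessors",
    device_=device_,
    rng=rng,
    n_jobs=1,
    byte_size=byte_size,
    forced_inference_dtype_=forced_dtype_,
    memory_saving_mode="auto",
    use_autocast_=use_autocast_,
)
```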
determine_precision ¶
```python
determine_precision(
    inference_precision: dtype | Literal["autocast", "auto"],
    device_: device,
) -> tuple[bool, dtype | None, int]
```
Decide whether to use autocast or a forced precision dtype.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `inference_precision` | `dtype \| Literal['autocast', 'auto']` | The requested inference precision: an explicit torch dtype, `"autocast"`, or `"auto"`. | *required* |
| `device_` | `device` | The device on which inference is run. | *required* |
Returns:

| Name | Type | Description |
| --- | --- | --- |
| `use_autocast_` | `bool` | True if mixed-precision autocast will be used. |
| `forced_inference_dtype_` | `dtype \| None` | If not None, the forced precision dtype for the model. |
| `byte_size` | `int` | The byte size per element for the chosen precision. |
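As a quick illustration (assuming the same `tabpfn.base` import path), passing `"auto"` delegates the choice to the helper, while passing an explicit dtype forces it:

```python
import torch
from tabpfn.base import determine_precision  # assumed import path

device_ = torch.device("cpu")

# Let the helper decide based on the device's capabilities.
use_autocast_, forced_dtype_, byte_size = determine_precision("auto", device_)

# Force float32: no autocast, and 4 bytes per element.
use_autocast_, forced_dtype_, byte_size = determine_precision(torch.float32, device_)
print(use_autocast_, forced_dtype_, byte_size)  # expected: False torch.float32 4
```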
initialize_tabpfn_model ¶
```python
initialize_tabpfn_model(
    model_path: str | Path | Literal["auto"],
    which: Literal["classifier", "regressor"],
    fit_mode: Literal["low_memory", "fit_preprocessors", "fit_with_cache"],
    static_seed: int,
) -> tuple[
    PerFeatureTransformer,
    InferenceConfig,
    FullSupportBarDistribution | None,
]
```
Common logic to load the TabPFN model, set up the random state, and optionally download the model.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_path` | `str \| Path \| Literal['auto']` | Path or directive (`"auto"`) to load the pre-trained model from. | *required* |
| `which` | `Literal['classifier', 'regressor']` | Which TabPFN model to load. | *required* |
| `fit_mode` | `Literal['low_memory', 'fit_preprocessors', 'fit_with_cache']` | Determines caching behavior. | *required* |
| `static_seed` | `int` | Random seed for reproducibility logic. | *required* |
Returns:

| Name | Type | Description |
| --- | --- | --- |
| `model` | `PerFeatureTransformer` | The loaded TabPFN model. |
| `config` | `InferenceConfig` | The configuration object associated with the loaded model. |
| `bar_distribution` | `FullSupportBarDistribution \| None` | The BarDistribution for regression (`None` when `which="classifier"`). |
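For example, a hedged sketch of loading a regressor (again assuming the `tabpfn.base` import path); per the return types above, `bar_distribution` should only be non-`None` for the regressor:

```python
from tabpfn.base import initialize_tabpfn_model  # assumed import path

# "auto" resolves the bundled checkpoint, downloading it if necessary.
model, config, bar_distribution = initialize_tabpfn_model(
    model_path="auto",
    which="regressor",
    fit_mode="low_memory",
    static_seed=42,
)

# The bar distribution defines the output buckets used for regression targets;
# loading with which="classifier" should instead yield bar_distribution = None.
print(type(bar_distribution))
```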